Exploring Timm vs Torchvision Resnet18 Difference
Reviewing the impact batchnorm initialization has on non-pretrained model performance
- Timm Vanilla
- TorchVision Vanilla
- What is causing this difference?
- Timm zero_init_last_bn=False
- TorchVision zero_init_residual=True
- Conclusion
- Next Steps
I was recently running some experiments on Imagenette to test out the new PolyLoss paper, and I noticed that my baseline model using resnet18 from torchvision consistently reached ~78% accuracy after 5 epochs, while the same baseline using resnet18 from timm consistently landed around 72%. For this post, I'm going to stick to one run per model, but this really should use at least 5 runs to make sure the issue isn't a poorly seeded run.
The first thing I am going to do is import fastai's vision module and download Imagenette, a smaller dataset for testing techniques that is much lighter than ImageNet.
```python
from fastai.vision.all import *

imagenette_url = 'https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz'
data_path = untar_data(imagenette_url, archive='/data/archive', data='/data/data')
```
Now, let's train a model using timm's resnet18 architecture. The newest version of fastai makes this really slick by allowing the user to pass 'resnet18' in as a string, which signals fastai's vision_learner to look the model up in the timm library.
```python
dls = ImageDataLoaders.from_folder(data_path, valid='val', item_tfms=Resize(256))
learn_timm = vision_learner(dls, 'resnet18', pretrained=False, metrics=accuracy)
learn_timm.fit_one_cycle(5)
```
72.7% accuracy when we run timm's resnet18 for 5 epochs.
Now let's do the same training, but with torchvision's resnet18 instead, which is used by passing the resnet18 function itself (not the string) to vision_learner.
```python
dls = ImageDataLoaders.from_folder(data_path, valid='val', item_tfms=Resize(256))
learn_torchvision = vision_learner(dls, resnet18, pretrained=False, metrics=accuracy)
learn_torchvision.fit_one_cycle(5)
```
78.2% accuracy when we run torchvision's resnet18 for 5 epochs.
This difference was pointed out in the fastai Discord channel, and Ross Wightman, the creator of timm, had some ideas. The first was to run the experiment multiple times, since the gap could just be variance. This was easy enough to test, and I saw the same pattern across the next 5 runs. The next thing he mentioned was something called zero-init, which I hadn't heard of before. The argument may be referred to as `zero_init_residual` or `zero_init_last_bn`. The timm library defaults this to True and torchvision defaults it to False. First, let's confirm that this difference explains our discrepancy between timm and torchvision, then I'll explain what it is doing, and lastly I will explain which is the better option.
```python
dls = ImageDataLoaders.from_folder(data_path, valid='val', item_tfms=Resize(256))
learn_timm_no_zero = vision_learner(dls, 'resnet18', pretrained=False, metrics=accuracy, zero_init_last_bn=False)
learn_timm_no_zero.fit_one_cycle(5)
```
```python
dls = ImageDataLoaders.from_folder(data_path, valid='val', item_tfms=Resize(256))
learn_tv_zero_bn = vision_learner(dls, partial(resnet18, zero_init_residual=True), pretrained=False, metrics=accuracy)
learn_tv_zero_bn.fit_one_cycle(5)
```
So what is this option that is swinging our accuracy by 5%? It controls whether the weight of the last batchnorm layer in each residual block (`bn2` in resnet18's BasicBlock) is initialized to 0 or to 1. We can see this by inspecting the weights in both models.
```python
dls = ImageDataLoaders.from_folder(data_path, valid='val', item_tfms=Resize(256))
learn_timm = vision_learner(dls, 'resnet18', pretrained=False, metrics=accuracy)
learn_timm.model[0].model.layer1
learn_timm.model[0].model.layer1[0].bn2.weight
```
```python
dls = ImageDataLoaders.from_folder(data_path, valid='val', item_tfms=Resize(256))
learn_torchvision = vision_learner(dls, resnet18, pretrained=False, metrics=accuracy)
learn_torchvision.model[0][4]
learn_torchvision.model[0][4][0].bn2.weight
```
The other thing that may catch your eye here is that timm's block shows two relu modules while torchvision's shows only one. This is just a printing difference: torchvision defines a single relu module and applies it twice in its forward function, so the printed module list doesn't show the full picture.
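To see why this initialization matters, here is a minimal sketch of a residual block (a simplified stand-in I wrote for illustration, not the actual timm or torchvision code). When the last batchnorm's weight starts at 0, the residual branch outputs 0 everywhere, so at initialization the whole block reduces to the identity (followed by a relu):

```python
import torch
import torch.nn as nn

class TinyBlock(nn.Module):
    """Simplified residual block: conv-bn-relu-conv-bn plus a skip connection."""
    def __init__(self, channels, zero_init_last_bn=True):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        if zero_init_last_bn:
            nn.init.zeros_(self.bn2.weight)  # the option under discussion

    def forward(self, x):
        out = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
        return self.relu(out + x)  # residual connection

torch.manual_seed(0)
x = torch.rand(2, 8, 4, 4)  # non-negative inputs so relu(x) == x

# Zero-init: bn2 multiplies the branch by 0, so the block is the identity at init
zeroed = TinyBlock(8, zero_init_last_bn=True).eval()
print(torch.allclose(zeroed(x), x))  # True

# Default init: the residual branch is active from the start
default = TinyBlock(8, zero_init_last_bn=False).eval()
print(torch.allclose(default(x), x))  # False
```

In other words, a zero-initialized network starts out behaving like a much shallower one, and the residual branches only come into play as training moves their batchnorm weights away from 0.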
After swapping these two defaults, the results flip: torchvision (now with zero-init) lands at the lower accuracy and timm (now without it) at the higher one. So clearly, setting this option to False is best, right? Not so fast. Ross says that while turning zero-init off performs better on short training runs, on longer runs the zero-init version catches up and actually outperforms the non-zeroed one.
The next thing to do is to test the claim that the zeroed version performs better on longer runs, and to try other initializations as well. Note that none of this matters when using pretrained weights, since the bn2 weights are then loaded from the checkpoint; it only comes into play when training from scratch and comparing performance, as was the case when this question arose.