Exploring Timm vs Torchvision Resnet18 Difference

I was recently doing some experiments on Imagenette, testing out the new PolyLoss paper, and I noticed that my baseline model using resnet18 from torchvision was consistently getting ~78% accuracy after 5 epochs, while the same baseline was consistently around 72% when I used resnet18 from timm instead. For this post I'm going to stick to one run per model, but a proper comparison should use at least 5 runs to rule out a poorly seeded run.

The first thing to do is import fastai's vision module and download Imagenette, a lightweight subset of ImageNet that is handy for quickly testing techniques.

```python
from fastai.vision.all import *

imagenette_url = 'https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz'
data_path = untar_data(imagenette_url, archive='/data/archive', data='/data/data')
```
Timm Vanilla
Now, let's train a model using timm's resnet18 architecture. The newest version of fastai makes this really slick by allowing a user to pass 'resnet18' in as a string. The string signals fastai's vision_learner to look the model up in the timm model library.
```python
dls = ImageDataLoaders.from_folder(data_path, valid='val', item_tfms=Resize(256))
learn_timm = vision_learner(dls, 'resnet18', pretrained=False, metrics=accuracy)
learn_timm.fit_one_cycle(5)
```
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 2.412845 | 2.793077 | 0.292484 | 00:25 |
1 | 1.847957 | 2.888684 | 0.314904 | 00:24 |
2 | 1.413061 | 1.247371 | 0.608662 | 00:24 |
3 | 1.124863 | 1.019361 | 0.675414 | 00:24 |
4 | 0.963823 | 0.848421 | 0.726879 | 00:24 |
72.7% accuracy when we run timm’s resnet18 for 5 epochs.
TorchVision Vanilla
Now let's do the same training, but using torchvision's resnet18, which is selected by passing the resnet18 function itself (not the string) to vision_learner.
```python
dls = ImageDataLoaders.from_folder(data_path, valid='val', item_tfms=Resize(256))
learn_torchvision = vision_learner(dls, resnet18, pretrained=False, metrics=accuracy)
learn_torchvision.fit_one_cycle(5)
```
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 2.369731 | 3.082903 | 0.196178 | 00:24 |
1 | 1.618597 | 2.213618 | 0.449936 | 00:24 |
2 | 1.195862 | 2.348811 | 0.519236 | 00:24 |
3 | 0.949387 | 0.815663 | 0.735541 | 00:25 |
4 | 0.721939 | 0.674278 | 0.781656 | 00:24 |
78.2% accuracy when we run torchvision’s resnet18 for 5 epochs.
What is causing this difference?
This difference was pointed out in the fastai Discord channel, and Ross Wightman, the creator of timm, had some ideas. The first was to run the experiment multiple times, since a single run has high variance. That was easy enough to test, and the next 5 runs showed the same pattern. The next thing he mentioned was something called zero_init, which I hadn't heard of before. The argument may be referred to as zero_init_residual or zero_init_last_bn. The timm library defaults it to True, while torchvision defaults it to False. First, let's confirm that this difference explains our discrepancy, then I'll explain what it is doing, and lastly which is the better option.
Timm zero_init_last_bn=False
```python
dls = ImageDataLoaders.from_folder(data_path, valid='val', item_tfms=Resize(256))
learn_timm_no_zero = vision_learner(dls, 'resnet18', pretrained=False, metrics=accuracy, zero_init_last_bn=False)
learn_timm_no_zero.fit_one_cycle(5)
```
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 2.349849 | 2.469171 | 0.282803 | 00:24 |
1 | 1.621092 | 3.614518 | 0.263949 | 00:24 |
2 | 1.222227 | 1.627461 | 0.538599 | 00:24 |
3 | 0.960997 | 0.816477 | 0.741401 | 00:24 |
4 | 0.718745 | 0.682136 | 0.787516 | 00:24 |

78.8% accuracy: timm with zero_init_last_bn=False now matches torchvision's default result.
TorchVision zero_init_residual=True
```python
dls = ImageDataLoaders.from_folder(data_path, valid='val', item_tfms=Resize(256))
learn_tv_zero_bn = vision_learner(dls, partial(resnet18, zero_init_residual=True), pretrained=False, metrics=accuracy)
learn_tv_zero_bn.fit_one_cycle(5)
```
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 2.453036 | 3.386052 | 0.208408 | 00:24 |
1 | 1.851731 | 2.367089 | 0.373758 | 00:24 |
2 | 1.414447 | 1.292271 | 0.581401 | 00:24 |
3 | 1.109155 | 0.982513 | 0.684331 | 00:24 |
4 | 0.927408 | 0.845072 | 0.729682 | 00:24 |

73.0% accuracy: torchvision with zero_init_residual=True now matches timm's default result.
So what is this option that is swinging our accuracy by roughly 5%? It controls how the weight (gamma) of the last batchnorm layer in each residual block (bn2 in a BasicBlock) is initialized: to 0 or to 1. Starting gamma at 0 zeroes out the residual branch at initialization, so each block initially passes its input straight through. Let's inspect both models to see the difference.
```python
dls = ImageDataLoaders.from_folder(data_path, valid='val', item_tfms=Resize(256))
learn_timm = vision_learner(dls, 'resnet18', pretrained=False, metrics=accuracy)
learn_timm.model[0].model.layer1
```
```
Sequential(
  (0): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act1): ReLU(inplace=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act2): ReLU(inplace=True)
  )
  (1): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act1): ReLU(inplace=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act2): ReLU(inplace=True)
  )
)
```
```python
learn_timm.model[0].model.layer1[0].bn2.weight
```
```
Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       requires_grad=True)
```
```python
dls = ImageDataLoaders.from_folder(data_path, valid='val', item_tfms=Resize(256))
learn_torchvision = vision_learner(dls, resnet18, pretrained=False, metrics=accuracy)
learn_torchvision.model[0][4]
```
```
Sequential(
  (0): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (1): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)
```
```python
learn_torchvision.model[0][4][0].bn2.weight
```
```
Parameter containing:
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], requires_grad=True)
```
The other thing that may catch your eye here is that timm's blocks show two ReLU modules (act1 and act2) while torchvision's show only one. This is just a printing artifact: torchvision reuses the same relu module twice in its forward function, so the module listing doesn't show the full picture. The two architectures are equivalent.
Conclusion
After swapping these two defaults, the results flip: torchvision (with zero_init_residual=True) now lands at the lower accuracy and timm (with zero_init_last_bn=False) at the higher one, confirming that this initialization is the cause of the gap. So clearly, setting this option to False is best, right? Not so fast. Ross says that while the non-zero initialization performs better on short runs like these, on longer training runs the zero-initialized version tends to catch up and actually outperform it.
Next Steps
The next step is to test the claim that the zeroed version wins over longer training, and to try other initializations as well. Note that none of this matters when using pretrained weights, since the bn2 weights are then loaded from the checkpoint; it only comes into play when training from scratch and comparing baselines, as was the case when this question arose.