Exploring Timm vs Torchvision Resnet18 Difference

fastai
technical
exploration
Reviewing the impact that batchnorm initialization has on non-pretrained model performance
Author

Kevin Bird

Published

May 2, 2022

I was recently running experiments on Imagenette to test out the new PolyLoss paper, and I noticed that my baseline model using resnet18 from torchvision consistently reached ~78% accuracy after 5 epochs, while the same baseline using resnet18 from timm consistently landed around 72%. For this post I'm going to stick to one run per model, though ideally this would use at least 5 runs to rule out a poorly seeded run.

The first thing I am going to do is import fastai's vision module and download Imagenette, a lighter-weight subset of ImageNet designed for quickly testing techniques.

from fastai.vision.all import *
imagenette_url = 'https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz'
data_path = untar_data(imagenette_url, archive='/data/archive', data='/data/data')

Timm Vanilla

Now, let's train a model using timm's resnet18 architecture. The newest version of fastai makes this really slick: passing 'resnet18' as a string signals fastai's vision_learner to look the architecture up in the timm model library.

dls = ImageDataLoaders.from_folder(data_path, valid='val', item_tfms=Resize(256))
learn_timm = vision_learner(dls, 'resnet18', pretrained=False, metrics=accuracy)
learn_timm.fit_one_cycle(5)
epoch train_loss valid_loss accuracy time
0 2.412845 2.793077 0.292484 00:25
1 1.847957 2.888684 0.314904 00:24
2 1.413061 1.247371 0.608662 00:24
3 1.124863 1.019361 0.675414 00:24
4 0.963823 0.848421 0.726879 00:24

72.7% accuracy when we run timm’s resnet18 for 5 epochs.

TorchVision Vanilla

Now let's do the same training, but with torchvision's resnet18, which is selected by passing the resnet18 function itself (not the string) to vision_learner.

dls = ImageDataLoaders.from_folder(data_path, valid='val', item_tfms=Resize(256))
learn_torchvision = vision_learner(dls, resnet18, pretrained=False, metrics=accuracy)
learn_torchvision.fit_one_cycle(5)
epoch train_loss valid_loss accuracy time
0 2.369731 3.082903 0.196178 00:24
1 1.618597 2.213618 0.449936 00:24
2 1.195862 2.348811 0.519236 00:24
3 0.949387 0.815663 0.735541 00:25
4 0.721939 0.674278 0.781656 00:24

78.2% accuracy when we run torchvision’s resnet18 for 5 epochs.

What is causing this difference?

This difference was pointed out in the fastai Discord channel, and Ross Wightman, the creator of timm, had some ideas. The first was to run the experiment multiple times, since the gap could just be variance. This was easy enough to test, and the next 5 runs showed the same pattern. The next thing he mentioned was something called zero-init, which I hadn't heard of before. The argument is named zero_init_last_bn in timm and zero_init_residual in torchvision; timm defaults it to True, while torchvision defaults it to False. First, let's confirm that flipping this setting removes the discrepancy between timm and torchvision, then I'll explain what it is doing, and lastly which is the better option.

Timm zero_init_last_bn=False

dls = ImageDataLoaders.from_folder(data_path, valid='val', item_tfms=Resize(256))
learn_timm_no_zero = vision_learner(dls, 'resnet18', pretrained=False, metrics=accuracy, zero_init_last_bn=False)

learn_timm_no_zero.fit_one_cycle(5)
epoch train_loss valid_loss accuracy time
0 2.349849 2.469171 0.282803 00:24
1 1.621092 3.614518 0.263949 00:24
2 1.222227 1.627461 0.538599 00:24
3 0.960997 0.816477 0.741401 00:24
4 0.718745 0.682136 0.787516 00:24

TorchVision zero_init_residual=True

dls = ImageDataLoaders.from_folder(data_path, valid='val', item_tfms=Resize(256))
learn_tv_zero_bn = vision_learner(dls, partial(resnet18, zero_init_residual=True), pretrained=False, metrics=accuracy)
learn_tv_zero_bn.fit_one_cycle(5)
epoch train_loss valid_loss accuracy time
0 2.453036 3.386052 0.208408 00:24
1 1.851731 2.367089 0.373758 00:24
2 1.414447 1.292271 0.581401 00:24
3 1.109155 0.982513 0.684331 00:24
4 0.927408 0.845072 0.729682 00:24

So what is this option that is swinging our accuracy by 5%? It controls whether the weight (gamma) of the last batchnorm layer in each residual block (bn2 in resnet18's BasicBlock) is initialized to 0 or to 1.

dls = ImageDataLoaders.from_folder(data_path, valid='val', item_tfms=Resize(256))
learn_timm = vision_learner(dls, 'resnet18', pretrained=False, metrics=accuracy)
learn_timm.model[0].model.layer1
Sequential(
  (0): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act1): ReLU(inplace=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act2): ReLU(inplace=True)
  )
  (1): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act1): ReLU(inplace=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act2): ReLU(inplace=True)
  )
)
learn_timm.model[0].model.layer1[0].bn2.weight
Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       requires_grad=True)
dls = ImageDataLoaders.from_folder(data_path, valid='val', item_tfms=Resize(256))
learn_torchvision = vision_learner(dls, resnet18, pretrained=False, metrics=accuracy)
learn_torchvision.model[0][4]
Sequential(
  (0): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (1): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)
learn_torchvision.model[0][4][0].bn2.weight
Parameter containing:
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], requires_grad=True)
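To see why this initialization matters, here is a minimal sketch (a hypothetical TinyBlock, not the actual timm or torchvision code) showing that zeroing the last batchnorm's weight makes the residual branch contribute nothing at initialization, so each block starts out as an identity map (plus the activation):

```python
import torch
import torch.nn as nn

class TinyBlock(nn.Module):
    """Toy residual block: x -> relu(x + bn(conv(x)))."""
    def __init__(self, channels, zero_init_last_bn=True):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(channels)
        if zero_init_last_bn:
            # gamma = 0 (and beta = 0 by default), so bn's output is exactly 0
            nn.init.zeros_(self.bn.weight)

    def forward(self, x):
        return torch.relu(x + self.bn(self.conv(x)))

x = torch.randn(2, 8, 4, 4)
zeroed = TinyBlock(8, zero_init_last_bn=True).eval()
# The residual branch contributes nothing, so the block reduces to relu(x):
# the network initially behaves like a much shallower net, which tends to
# stabilize early training of deep residual networks.
assert torch.equal(zeroed(x), torch.relu(x))
```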

The other thing that may catch your eye here is that timm shows a second ReLU (act2) while torchvision shows only one. This is just a difference in module bookkeeping: torchvision defines a single relu module and calls it twice in its forward function, so the printed model doesn't show the full picture.
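A simplified sketch of the two layouts (hypothetical classes, not the real BasicBlock implementations) shows that since ReLU has no parameters, registering one module and reusing it is computationally identical to registering two:

```python
import torch
import torch.nn as nn

class SharedRelu(nn.Module):       # torchvision-style: one relu, called twice
    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()
    def forward(self, x):
        return self.relu(self.relu(x) + x)

class TwoRelus(nn.Module):         # timm-style: act1 and act2 both registered
    def __init__(self):
        super().__init__()
        self.act1 = nn.ReLU()
        self.act2 = nn.ReLU()
    def forward(self, x):
        return self.act2(self.act1(x) + x)

x = torch.randn(4)
# ReLU is stateless, so both layouts compute the same thing; only the
# printed module tree differs.
assert torch.equal(SharedRelu()(x), TwoRelus()(x))
```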

Conclusion

After swapping these two defaults, the results flip: timm (zero-init disabled) now reaches the higher accuracy, and torchvision (zero-init enabled) the lower one. So clearly, setting this option to False is best, right? Not so fast. Ross says that while models train faster over short runs with this option set to False, on longer training runs the zero-initialized version catches up and actually outperforms the non-zeroed version.

Next Steps

The next thing to do is to test the claim that the zeroed version performs better on longer runs, and to try other initializations as well. Note that this is not an issue when using pretrained weights, since the bn2 weights are then loaded from the checkpoint; it only matters when training from scratch, as was the case when this question arose.