Loading Pretrained Pytorch Inception_V3 Weights #16

@asy51

I am getting state_dict mismatches when loading the pretrained Inception_V3 weights:

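import torch
from torch import nn
from torchvision.models import densenet121, inception_v3, resnet50

class Backbone(nn.Module):
    def __init__(self, net='inceptionv3'):
        super().__init__()
        if net == 'inceptionv3':
            base_model = inception_v3()
        elif net == 'densenet121':
            base_model = densenet121()
        elif net == 'resnet50':
            base_model = resnet50()
        # Keep every top-level child except the last one (the classifier head)
        encoder_layers = list(base_model.children())
        self.backbone = nn.Sequential(*encoder_layers[:-1])

    def forward(self, x):
        return self.backbone(x)

# RAD is defined elsewhere in the notebook; RAD[net] is what gets passed to torch.load
net = 'inceptionv3'
backbone = Backbone(net)
backbone.load_state_dict(torch.load(RAD[net]))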

Error message:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[21], line 18
     16 net = 'inceptionv3'
     17 backbone = Backbone(net)
---> 18 backbone.load_state_dict(torch.load(RAD[net]))

...

RuntimeError: Error(s) in loading state_dict for Backbone:
	Missing key(s) in state_dict: "backbone.15.conv0.conv.weight", "backbone.15.conv0.bn.weight", "backbone.15.conv0.bn.bias", "backbone.15.conv0.bn.running_mean", "backbone.15.conv0.bn.running_var", "backbone.15.conv1.conv.weight", "backbone.15.conv1.bn.weight", "backbone.15.conv1.bn.bias", "backbone.15.conv1.bn.running_mean", "backbone.15.conv1.bn.running_var", "backbone.15.fc.weight", "backbone.15.fc.bias", "backbone.16.branch3x3_2.conv.weight", "backbone.16.branch3x3_2.bn.weight", "backbone.16.branch3x3_2.bn.bias", "backbone.16.branch3x3_2.bn.running_mean", "backbone.16.branch3x3_2.bn.running_var", "backbone.16.branch7x7x3_1.conv.weight", "backbone.16.branch7x7x3_1.bn.weight", "backbone.16.branch7x7x3_1.bn.bias", "backbone.16.branch7x7x3_1.bn.running_mean", "backbone.16.branch7x7x3_1.bn.running_var", "backbone.16.branch7x7x3_2.conv.weight", "backbone.16.branch7x7x3_2.bn.weight", "backbone.16.branch7x7x3_2.bn.bias", "backbone.16.branch7x7x3_2.bn.running_mean", "backbone.16.branch7x7x3_2.bn.running_var", "backbone.16.branch7x7x3_3.conv.weight", "backbone.16.branch7x7x3_3.bn.weight", "backbone.16.branch7x7x3_3.bn.bias", "backbone.16.branch7x7x3_3.bn.running_mean", "backbone.16.branch7x7x3_3.bn.running_var", "backbone.16.branch7x7x3_4.conv.weight", "backbone.16.branch7x7x3_4.bn.weight", "backbone.16.branch7x7x3_4.bn.bias", "backbone.16.branch7x7x3_4.bn.running_mean", "backbone.16.branch7x7x3_4.bn.running_var", "backbone.18.branch1x1.conv.weight", "backbone.18.branch1x1.bn.weight", "backbone.18.branch1x1.bn.bias", "backbone.18.branch1x1.bn.running_mean", "backbone.18.branch1x1.bn.running_var", "backbone.18.branch3x3_1.conv.weight", "backbone.18.branch3x3_1.bn.weight", "backbone.18.branch3x3_1.bn.bias", "backbone.18.branch3x3_1.bn.running_mean", "backbone.18.branch3x3_1.bn.running_var", "backbone.18.branch3x3_2a.conv.weight", "backbone.18.branch3x3_2a.bn.weight", "backbone.18.branch3x3_2a.bn.bias", "backbone.18.branch3x3_2a.bn.running_mean", 
"backbone.18.branch3x3_2a.bn.running_var", "backbone.18.branch3x3_2b.conv.weight", "backbone.18.branch3x3_2b.bn.weight", "backbone.18.branch3x3_2b.bn.bias", "backbone.18.branch3x3_2b.bn.running_mean", "backbone.18.branch3x3_2b.bn.running_var", "backbone.18.branch3x3dbl_1.conv.weight", "backbone.18.branch3x3dbl_1.bn.weight", "backbone.18.branch3x3dbl_1.bn.bias", "backbone.18.branch3x3dbl_1.bn.running_mean", "backbone.18.branch3x3dbl_1.bn.running_var", "backbone.18.branch3x3dbl_2.conv.weight", "backbone.18.branch3x3dbl_2.bn.weight", "backbone.18.branch3x3dbl_2.bn.bias", "backbone.18.branch3x3dbl_2.bn.running_mean", "backbone.18.branch3x3dbl_2.bn.running_var", "backbone.18.branch3x3dbl_3a.conv.weight", "backbone.18.branch3x3dbl_3a.bn.weight", "backbone.18.branch3x3dbl_3a.bn.bias", "backbone.18.branch3x3dbl_3a.bn.running_mean", "backbone.18.branch3x3dbl_3a.bn.running_var", "backbone.18.branch3x3dbl_3b.conv.weight", "backbone.18.branch3x3dbl_3b.bn.weight", "backbone.18.branch3x3dbl_3b.bn.bias", "backbone.18.branch3x3dbl_3b.bn.running_mean", "backbone.18.branch3x3dbl_3b.bn.running_var", "backbone.18.branch_pool.conv.weight", "backbone.18.branch_pool.bn.weight", "backbone.18.branch_pool.bn.bias", "backbone.18.branch_pool.bn.running_mean", "backbone.18.branch_pool.bn.running_var". 
	Unexpected key(s) in state_dict: "backbone.15.branch3x3_1.conv.weight", "backbone.15.branch3x3_1.bn.weight", "backbone.15.branch3x3_1.bn.bias", "backbone.15.branch3x3_1.bn.running_mean", "backbone.15.branch3x3_1.bn.running_var", "backbone.15.branch3x3_1.bn.num_batches_tracked", "backbone.15.branch3x3_2.conv.weight", "backbone.15.branch3x3_2.bn.weight", "backbone.15.branch3x3_2.bn.bias", "backbone.15.branch3x3_2.bn.running_mean", "backbone.15.branch3x3_2.bn.running_var", "backbone.15.branch3x3_2.bn.num_batches_tracked", "backbone.15.branch7x7x3_1.conv.weight", "backbone.15.branch7x7x3_1.bn.weight", "backbone.15.branch7x7x3_1.bn.bias", "backbone.15.branch7x7x3_1.bn.running_mean", "backbone.15.branch7x7x3_1.bn.running_var", "backbone.15.branch7x7x3_1.bn.num_batches_tracked", "backbone.15.branch7x7x3_2.conv.weight", "backbone.15.branch7x7x3_2.bn.weight", "backbone.15.branch7x7x3_2.bn.bias", "backbone.15.branch7x7x3_2.bn.running_mean", "backbone.15.branch7x7x3_2.bn.running_var", "backbone.15.branch7x7x3_2.bn.num_batches_tracked", "backbone.15.branch7x7x3_3.conv.weight", "backbone.15.branch7x7x3_3.bn.weight", "backbone.15.branch7x7x3_3.bn.bias", "backbone.15.branch7x7x3_3.bn.running_mean", "backbone.15.branch7x7x3_3.bn.running_var", "backbone.15.branch7x7x3_3.bn.num_batches_tracked", "backbone.15.branch7x7x3_4.conv.weight", "backbone.15.branch7x7x3_4.bn.weight", "backbone.15.branch7x7x3_4.bn.bias", "backbone.15.branch7x7x3_4.bn.running_mean", "backbone.15.branch7x7x3_4.bn.running_var", "backbone.15.branch7x7x3_4.bn.num_batches_tracked", "backbone.16.branch1x1.conv.weight", "backbone.16.branch1x1.bn.weight", "backbone.16.branch1x1.bn.bias", "backbone.16.branch1x1.bn.running_mean", "backbone.16.branch1x1.bn.running_var", "backbone.16.branch1x1.bn.num_batches_tracked", "backbone.16.branch3x3_2a.conv.weight", "backbone.16.branch3x3_2a.bn.weight", "backbone.16.branch3x3_2a.bn.bias", "backbone.16.branch3x3_2a.bn.running_mean", "backbone.16.branch3x3_2a.bn.running_var", 
"backbone.16.branch3x3_2a.bn.num_batches_tracked", "backbone.16.branch3x3_2b.conv.weight", "backbone.16.branch3x3_2b.bn.weight", "backbone.16.branch3x3_2b.bn.bias", "backbone.16.branch3x3_2b.bn.running_mean", "backbone.16.branch3x3_2b.bn.running_var", "backbone.16.branch3x3_2b.bn.num_batches_tracked", "backbone.16.branch3x3dbl_1.conv.weight", "backbone.16.branch3x3dbl_1.bn.weight", "backbone.16.branch3x3dbl_1.bn.bias", "backbone.16.branch3x3dbl_1.bn.running_mean", "backbone.16.branch3x3dbl_1.bn.running_var", "backbone.16.branch3x3dbl_1.bn.num_batches_tracked", "backbone.16.branch3x3dbl_2.conv.weight", "backbone.16.branch3x3dbl_2.bn.weight", "backbone.16.branch3x3dbl_2.bn.bias", "backbone.16.branch3x3dbl_2.bn.running_mean", "backbone.16.branch3x3dbl_2.bn.running_var", "backbone.16.branch3x3dbl_2.bn.num_batches_tracked", "backbone.16.branch3x3dbl_3a.conv.weight", "backbone.16.branch3x3dbl_3a.bn.weight", "backbone.16.branch3x3dbl_3a.bn.bias", "backbone.16.branch3x3dbl_3a.bn.running_mean", "backbone.16.branch3x3dbl_3a.bn.running_var", "backbone.16.branch3x3dbl_3a.bn.num_batches_tracked", "backbone.16.branch3x3dbl_3b.conv.weight", "backbone.16.branch3x3dbl_3b.bn.weight", "backbone.16.branch3x3dbl_3b.bn.bias", "backbone.16.branch3x3dbl_3b.bn.running_mean", "backbone.16.branch3x3dbl_3b.bn.running_var", "backbone.16.branch3x3dbl_3b.bn.num_batches_tracked", "backbone.16.branch_pool.conv.weight", "backbone.16.branch_pool.bn.weight", "backbone.16.branch_pool.bn.bias", "backbone.16.branch_pool.bn.running_mean", "backbone.16.branch_pool.bn.running_var", "backbone.16.branch_pool.bn.num_batches_tracked". 
	size mismatch for backbone.16.branch3x3_1.conv.weight: copying a param with shape torch.Size([384, 1280, 1, 1]) from checkpoint, the shape in current model is torch.Size([192, 768, 1, 1]).
	size mismatch for backbone.16.branch3x3_1.bn.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for backbone.16.branch3x3_1.bn.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for backbone.16.branch3x3_1.bn.running_mean: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for backbone.16.branch3x3_1.bn.running_var: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for backbone.17.branch1x1.conv.weight: copying a param with shape torch.Size([320, 2048, 1, 1]) from checkpoint, the shape in current model is torch.Size([320, 1280, 1, 1]).
	size mismatch for backbone.17.branch3x3_1.conv.weight: copying a param with shape torch.Size([384, 2048, 1, 1]) from checkpoint, the shape in current model is torch.Size([384, 1280, 1, 1]).
	size mismatch for backbone.17.branch3x3dbl_1.conv.weight: copying a param with shape torch.Size([448, 2048, 1, 1]) from checkpoint, the shape in current model is torch.Size([448, 1280, 1, 1]).
	size mismatch for backbone.17.branch_pool.conv.weight: copying a param with shape torch.Size([192, 2048, 1, 1]) from checkpoint, the shape in current model is torch.Size([192, 1280, 1, 1]).
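
The key lists above are long, so grouping them by top-level child index can make the pattern easier to see. The following is a minimal sketch using only the standard library; `group_by_child` is a hypothetical helper name, and the sample keys are a small subset copied from the error message, not the full lists.

```python
from collections import defaultdict


def group_by_child(keys, prefix="backbone."):
    """Map each top-level child index to the set of sub-module names under it."""
    groups = defaultdict(set)
    for key in keys:
        if not key.startswith(prefix):
            continue
        parts = key[len(prefix):].split(".")
        # Expect "<index>.<submodule>.<...>"; skip anything else.
        if len(parts) < 2 or not parts[0].isdigit():
            continue
        groups[int(parts[0])].add(parts[1])
    return dict(groups)


# A few keys copied from the error message above (not the full lists).
missing_in_checkpoint = [
    "backbone.15.conv0.conv.weight",
    "backbone.15.conv1.conv.weight",
    "backbone.15.fc.weight",
    "backbone.18.branch1x1.conv.weight",
]
unexpected_in_checkpoint = [
    "backbone.15.branch3x3_1.conv.weight",
    "backbone.15.branch7x7x3_1.conv.weight",
    "backbone.16.branch1x1.conv.weight",
]

model_only = group_by_child(missing_in_checkpoint)
checkpoint_only = group_by_child(unexpected_in_checkpoint)

# Print, per child index, which sub-modules exist on only one side.
for idx in sorted(set(model_only) | set(checkpoint_only)):
    print(
        idx,
        "model only:", sorted(model_only.get(idx, [])),
        "| checkpoint only:", sorted(checkpoint_only.get(idx, [])),
    )
```

With real objects, the two inputs would be the set differences between `backbone.state_dict().keys()` and the keys of the loaded checkpoint. In the sample above, index 15 of the model holds `conv0`, `conv1`, and `fc`, while index 15 of the checkpoint holds `branch*` sub-modules, which suggests the two sequences are offset by one child from that point on.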

torch and torchvision versions: ('2.0.0+cu117', '0.15.1+cu117')
The repository does not seem to provide a conda environment.yaml or pip requirements.txt file.
Please advise on how to load the weights! 🙏
