From 821c0161079acb11e36ae7d5ac03b1a35d4aa205 Mon Sep 17 00:00:00 2001 From: Giles Billenness Date: Wed, 16 Mar 2022 16:28:55 +0000 Subject: [PATCH 1/2] Change load pretrained to accept models from MoBY Added a check to remove encoder prefixes as done in https://github.com/SwinTransformer/Swin-Transformer-Object-Detection Added a head check to reinit if the head key isn't present --- utils.py | 45 +++++++++++++++++++++++++++------------------ 1 file changed, 27 insertions(+), 18 deletions(-) diff --git a/utils.py b/utils.py index 94890ca46..4e60700ac 100644 --- a/utils.py +++ b/utils.py @@ -48,6 +48,10 @@ def load_pretrained(config, model, logger): checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu') state_dict = checkpoint['model'] + # for MoBY, preplace prefixes + if sorted(list(state_dict.keys()))[0].startswith('encoder'): + state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')} + # delete relative_position_index since we always re-init it relative_position_index_keys = [k for k in state_dict.keys() if "relative_position_index" in k] for k in relative_position_index_keys: @@ -105,24 +109,29 @@ def load_pretrained(config, model, logger): state_dict[k] = absolute_pos_embed_pretrained_resized # check classifier, if not match, then re-init classifier to zero - head_bias_pretrained = state_dict['head.bias'] - Nc1 = head_bias_pretrained.shape[0] - Nc2 = model.head.bias.shape[0] - if (Nc1 != Nc2): - if Nc1 == 21841 and Nc2 == 1000: - logger.info("loading ImageNet-22K weight to ImageNet-1K ......") - map22kto1k_path = f'data/map22kto1k.txt' - with open(map22kto1k_path) as f: - map22kto1k = f.readlines() - map22kto1k = [int(id22k.strip()) for id22k in map22kto1k] - state_dict['head.weight'] = state_dict['head.weight'][map22kto1k, :] - state_dict['head.bias'] = state_dict['head.bias'][map22kto1k] - else: - torch.nn.init.constant_(model.head.bias, 0.) - torch.nn.init.constant_(model.head.weight, 0.) - del state_dict['head.weight'] - del state_dict['head.bias'] - logger.warning(f"Error in loading classifier head, re-init classifier head to 0") + if ('head.bias' in state_dict): + head_bias_pretrained = state_dict['head.bias'] + Nc1 = head_bias_pretrained.shape[0] + Nc2 = model.head.bias.shape[0] + if (Nc1 != Nc2): + if Nc1 == 21841 and Nc2 == 1000: + logger.info("loading ImageNet-22K weight to ImageNet-1K ......") + map22kto1k_path = f'data/map22kto1k.txt' + with open(map22kto1k_path) as f: + map22kto1k = f.readlines() + map22kto1k = [int(id22k.strip()) for id22k in map22kto1k] + state_dict['head.weight'] = state_dict['head.weight'][map22kto1k, :] + state_dict['head.bias'] = state_dict['head.bias'][map22kto1k] + else: + torch.nn.init.constant_(model.head.bias, 0.) + torch.nn.init.constant_(model.head.weight, 0.) + del state_dict['head.weight'] + del state_dict['head.bias'] + logger.warning(f"Error in loading classifier head, re-init classifier head to 0") + else: + torch.nn.init.constant_(model.head.bias, 0.) + torch.nn.init.constant_(model.head.weight, 0.) + logger.warning(f"Error in loading classifier head, re-init classifier head to 0") msg = model.load_state_dict(state_dict, strict=False) logger.warning(msg) From 5e6096a69d3ae58421ee700f30382ad8f781afae Mon Sep 17 00:00:00 2001 From: Giles Billenness Date: Tue, 21 Feb 2023 20:18:27 +0000 Subject: [PATCH 2/2] Make like original --- README.md | 2 ++ utils.py | 1 + 2 files changed, 3 insertions(+) diff --git a/README.md b/README.md index 64db07ef8..7770a68e8 100644 --- a/README.md +++ b/README.md @@ -183,6 +183,8 @@ Note: * indicates multi-scale testing. (`Note please report accuracy numbers and provide trained models in your new repository to facilitate others to get sense of correctness and model behavior`) +[04/06/2022] Swin Transformer for Audio Classification: [Hierarchical Token Semantic Audio Transformer](https://github.com/RetroCirce/HTS-Audio-Transformer). + [12/21/2021] Swin Transformer for StyleGAN: [StyleSwin](https://github.com/microsoft/StyleSwin) [12/13/2021] Swin Transformer for Face Recognition: [FaceX-Zoo](https://github.com/JDAI-CV/FaceX-Zoo) diff --git a/utils.py b/utils.py index 4e60700ac..472107bda 100644 --- a/utils.py +++ b/utils.py @@ -48,6 +48,7 @@ def load_pretrained(config, model, logger): checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu') state_dict = checkpoint['model'] + # for MoBY, preplace prefixes if sorted(list(state_dict.keys()))[0].startswith('encoder'): state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}