MCPcopy
hub / github.com/microsoft/Swin-Transformer / load_pretrained

Function load_pretrained

utils.py:45–132  ·  view source on GitHub ↗
(config, model, logger)

Source from the content-addressed store, hash-verified

43
44
45def load_pretrained(config, model, logger):
46 logger.info(f"==============> Loading weight {config.MODEL.PRETRAINED} for fine-tuning......")
47 checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu')
48 state_dict = checkpoint['model']
49
50 # delete relative_position_index since we always re-init it
51 relative_position_index_keys = [k for k in state_dict.keys() if "relative_position_index" in k]
52 for k in relative_position_index_keys:
53 del state_dict[k]
54
55 # delete relative_coords_table since we always re-init it
56 relative_position_index_keys = [k for k in state_dict.keys() if "relative_coords_table" in k]
57 for k in relative_position_index_keys:
58 del state_dict[k]
59
60 # delete attn_mask since we always re-init it
61 attn_mask_keys = [k for k in state_dict.keys() if "attn_mask" in k]
62 for k in attn_mask_keys:
63 del state_dict[k]
64
65 # bicubic interpolate relative_position_bias_table if not match
66 relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
67 for k in relative_position_bias_table_keys:
68 relative_position_bias_table_pretrained = state_dict[k]
69 relative_position_bias_table_current = model.state_dict()[k]
70 L1, nH1 = relative_position_bias_table_pretrained.size()
71 L2, nH2 = relative_position_bias_table_current.size()
72 if nH1 != nH2:
73 logger.warning(f"Error in loading {k}, passing......")
74 else:
75 if L1 != L2:
76 # bicubic interpolate relative_position_bias_table if not match
77 S1 = int(L1 ** 0.5)
78 S2 = int(L2 ** 0.5)
79 relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(
80 relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2),
81 mode='bicubic')
82 state_dict[k] = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0)
83
84 # bicubic interpolate absolute_pos_embed if not match
85 absolute_pos_embed_keys = [k for k in state_dict.keys() if "absolute_pos_embed" in k]
86 for k in absolute_pos_embed_keys:
87 # dpe
88 absolute_pos_embed_pretrained = state_dict[k]
89 absolute_pos_embed_current = model.state_dict()[k]
90 _, L1, C1 = absolute_pos_embed_pretrained.size()
91 _, L2, C2 = absolute_pos_embed_current.size()
92 if C1 != C1:
93 logger.warning(f"Error in loading {k}, passing......")
94 else:
95 if L1 != L2:
96 S1 = int(L1 ** 0.5)
97 S2 = int(L2 ** 0.5)
98 absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1)
99 absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2)
100 absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate(
101 absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic')
102 absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1)

Callers 1

main · Function · 0.90

Calls 2

state_dict · Method · 0.80
load_state_dict · Method · 0.80

Tested by

no test coverage detected