hub / github.com/microsoft/Cream / load_pretrained

Function load_pretrained

TinyViT/utils.py:117–242
(config, model, logger)

Source from the content-addressed store, hash-verified

def load_pretrained(config, model, logger):
    logger.info(
        f"==============> Loading weight {config.MODEL.PRETRAINED} for fine-tuning......")
    checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu')
    state_dict = checkpoint['model']

    # delete relative_position_index since we always re-init it
    relative_position_index_keys = [
        k for k in state_dict.keys() if "relative_position_index" in k]
    for k in relative_position_index_keys:
        del state_dict[k]

    # delete relative_coords_table since we always re-init it
    relative_position_index_keys = [
        k for k in state_dict.keys() if "relative_coords_table" in k]
    for k in relative_position_index_keys:
        del state_dict[k]

    # delete attn_mask since we always re-init it
    attn_mask_keys = [k for k in state_dict.keys() if "attn_mask" in k]
    for k in attn_mask_keys:
        del state_dict[k]

    model_state_dict = model.state_dict()

    # bicubic interpolate relative_position_bias_table if not match
    relative_position_bias_table_keys = [
        k for k in state_dict.keys() if "relative_position_bias_table" in k]
    for k in relative_position_bias_table_keys:
        relative_position_bias_table_pretrained = state_dict[k]
        relative_position_bias_table_current = model_state_dict[k]
        L1, nH1 = relative_position_bias_table_pretrained.size()
        L2, nH2 = relative_position_bias_table_current.size()
        if nH1 != nH2:
            logger.warning(f"Error in loading {k}, passing......")
        else:
            if L1 != L2:
                # bicubic interpolate relative_position_bias_table if not match
                S1 = int(L1 ** 0.5)
                S2 = int(L2 ** 0.5)
                relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(
                    relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2),
                    mode='bicubic')
                state_dict[k] = relative_position_bias_table_pretrained_resized.view(
                    nH2, L2).permute(1, 0)

    # bicubic interpolate attention_biases if not match
    relative_position_bias_table_keys = [
        k for k in state_dict.keys() if "attention_biases" in k]
    for k in relative_position_bias_table_keys:
        relative_position_bias_table_pretrained = state_dict[k]
        relative_position_bias_table_current = model_state_dict[k]
        nH1, L1 = relative_position_bias_table_pretrained.size()
        nH2, L2 = relative_position_bias_table_current.size()
        if nH1 != nH2:
            logger.warning(f"Error in loading {k}, passing......")
        else:
            if L1 != L2:
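
The three "delete ... since we always re-init it" passes in the listing follow one pattern: collect the matching keys first, then delete them. A minimal sketch of that pattern, assuming a plain dict stands in for the checkpoint's state dict; the helper name `drop_keys_containing`, the constant `REINIT_BUFFERS`, and the example key names are hypothetical and not part of TinyViT:

```python
def drop_keys_containing(state_dict, substrings):
    """Remove, in place, every entry whose key contains any of the substrings.

    Returns the list of removed keys so the caller can log them.
    """
    # Collect first, then delete: removing entries while iterating over
    # the dict itself would raise a RuntimeError.
    doomed = [k for k in state_dict if any(s in k for s in substrings)]
    for k in doomed:
        del state_dict[k]
    return doomed


# Substrings taken from the listing; the buffers they match are rebuilt
# by the model, so the checkpoint copies are discarded.
REINIT_BUFFERS = ("relative_position_index", "relative_coords_table", "attn_mask")

# Hypothetical key names, for illustration only.
checkpoint_weights = {
    "layers.0.blocks.0.attn.relative_position_index": "buffer",
    "layers.0.blocks.1.attn_mask": "buffer",
    "layers.0.blocks.0.attn.qkv.weight": "weights",
}
removed = drop_keys_containing(checkpoint_weights, REINIT_BUFFERS)
print(sorted(checkpoint_weights))
```

Folding the three substrings into one pass is a design choice of this sketch; the listing keeps them as separate loops, which behaves the same way.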

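The `relative_position_bias_table` branch of the listing treats the `(L, nH)` table as `nH` square images of side `sqrt(L)` and resizes them bicubically. The sketch below isolates that step, assuming PyTorch is installed. The function name `resize_bias_table` is hypothetical, and the perfect-square check is an addition of this sketch (the listing uses `int(L ** 0.5)` without checking). It does not cover the `attention_biases` branch, which is cut off in the listing.

```python
import math

import torch
import torch.nn.functional as F


def resize_bias_table(table, new_len):
    """Bicubically resize an (L, nH) bias table to (new_len, nH)."""
    old_len, num_heads = table.shape
    s1 = math.isqrt(old_len)
    s2 = math.isqrt(new_len)
    # Assumption of this sketch: both lengths must be perfect squares.
    if s1 * s1 != old_len or s2 * s2 != new_len:
        raise ValueError("table lengths must be perfect squares")
    # (L, nH) -> (nH, L) -> (1, nH, S1, S1): one square "image" per head.
    grid = table.permute(1, 0).reshape(1, num_heads, s1, s1)
    resized = F.interpolate(grid, size=(s2, s2), mode="bicubic")
    # (1, nH, S2, S2) -> (nH, L2) -> (L2, nH): back to the table layout.
    return resized.reshape(num_heads, new_len).permute(1, 0)


# A 13x13 table (169 entries) with 3 heads, resized to 23x23 (529 entries).
small = torch.zeros(169, 3)
large = resize_bias_table(small, 529)
print(tuple(large.shape))
```

Because the heads sit in the channel dimension, each head's table is interpolated independently, which is why the listing skips the resize when the head counts differ.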
Callers 1

main · Function · 0.90

Calls 4

to · Method · 0.80
state_dict · Method · 0.45
size · Method · 0.45
load_state_dict · Method · 0.45

Tested by

no test coverage detected