MCPcopy
hub / github.com/microsoft/Swin-Transformer / load_pretrained

Function load_pretrained

utils_moe.py:64–172  ·  view source on GitHub ↗
(config, model, logger)

Source from the content-addressed store, hash-verified

62
63
64def load_pretrained(config, model, logger):
65 global_rank = dist.get_rank()
66 logger.info(f"==============> Rank[{global_rank}] Loading weight {config.MODEL.PRETRAINED} for fine-tuning......")
67 if config.MODEL.PRETRAINED.endswith(f'.pth'):
68 if config.TRAIN.MOE.SAVE_MASTER:
69 pretrained_path = config.MODEL.PRETRAINED + f'.global'
70 else:
71 pretrained_path = config.MODEL.PRETRAINED + f'.rank{global_rank}'
72 logger.info(f"===> Rank[{global_rank}] Re-formatting checkpoint name to {pretrained_path}......")
73 else:
74 pretrained_path = config.MODEL.PRETRAINED
75
76 if pretrained_path.endswith(f'.rank{global_rank}'):
77 checkpoint = torch.load(pretrained_path, map_location='cpu')
78 if os.path.exists(pretrained_path.replace(f'.rank{global_rank}', f'.master')):
79 checkpoint_master = torch.load(pretrained_path.replace(f'.rank{global_rank}', f'.master'),
80 map_location='cpu')
81 state_dict = merge_moe_model_state_dict(checkpoint['model'], checkpoint_master['model'])
82 else:
83 state_dict = checkpoint['model']
84 elif pretrained_path.endswith(f'.pth.global'):
85 checkpoint = torch.load(pretrained_path, map_location='cpu')
86 state_dict = checkpoint['model']
87 else:
88 raise NotImplementedError(f"{config.MODEL.PRETRAINED} file error...")
89
90 # delete relative_position_index since we always re-init it
91 relative_position_index_keys = [k for k in state_dict.keys() if "relative_position_index" in k]
92 for k in relative_position_index_keys:
93 del state_dict[k]
94
95 # delete relative_coords_table since we always re-init it
96 relative_position_index_keys = [k for k in state_dict.keys() if "relative_coords_table" in k]
97 for k in relative_position_index_keys:
98 del state_dict[k]
99
100 # delete attn_mask since we always re-init it
101 attn_mask_keys = [k for k in state_dict.keys() if "attn_mask" in k]
102 for k in attn_mask_keys:
103 del state_dict[k]
104
105 # bicubic interpolate relative_position_bias_table if not match
106 relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
107 for k in relative_position_bias_table_keys:
108 relative_position_bias_table_pretrained = state_dict[k]
109 relative_position_bias_table_current = model.state_dict()[k]
110 L1, nH1 = relative_position_bias_table_pretrained.size()
111 L2, nH2 = relative_position_bias_table_current.size()
112 if nH1 != nH2:
113 logger.warning(f"Error in loading {k}, passing......")
114 else:
115 if L1 != L2:
116 # bicubic interpolate relative_position_bias_table if not match
117 S1 = int(L1 ** 0.5)
118 S2 = int(L2 ** 0.5)
119 relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(
120 relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2),
121 mode='bicubic')

Callers 1

main · Function · 0.90

Calls 3

state_dict · Method · 0.80
load_state_dict · Method · 0.80

Tested by

no test coverage detected