(config, model, logger)
| 62 | |
| 63 | |
| 64 | def load_pretrained(config, model, logger): |
| 65 | global_rank = dist.get_rank() |
| 66 | logger.info(f"==============> Rank[{global_rank}] Loading weight {config.MODEL.PRETRAINED} for fine-tuning......") |
| 67 | if config.MODEL.PRETRAINED.endswith(f'.pth'): |
| 68 | if config.TRAIN.MOE.SAVE_MASTER: |
| 69 | pretrained_path = config.MODEL.PRETRAINED + f'.global' |
| 70 | else: |
| 71 | pretrained_path = config.MODEL.PRETRAINED + f'.rank{global_rank}' |
| 72 | logger.info(f"===> Rank[{global_rank}] Re-formatting checkpoint name to {pretrained_path}......") |
| 73 | else: |
| 74 | pretrained_path = config.MODEL.PRETRAINED |
| 75 | |
| 76 | if pretrained_path.endswith(f'.rank{global_rank}'): |
| 77 | checkpoint = torch.load(pretrained_path, map_location='cpu') |
| 78 | if os.path.exists(pretrained_path.replace(f'.rank{global_rank}', f'.master')): |
| 79 | checkpoint_master = torch.load(pretrained_path.replace(f'.rank{global_rank}', f'.master'), |
| 80 | map_location='cpu') |
| 81 | state_dict = merge_moe_model_state_dict(checkpoint['model'], checkpoint_master['model']) |
| 82 | else: |
| 83 | state_dict = checkpoint['model'] |
| 84 | elif pretrained_path.endswith(f'.pth.global'): |
| 85 | checkpoint = torch.load(pretrained_path, map_location='cpu') |
| 86 | state_dict = checkpoint['model'] |
| 87 | else: |
| 88 | raise NotImplementedError(f"{config.MODEL.PRETRAINED} file error...") |
| 89 | |
| 90 | # delete relative_position_index since we always re-init it |
| 91 | relative_position_index_keys = [k for k in state_dict.keys() if "relative_position_index" in k] |
| 92 | for k in relative_position_index_keys: |
| 93 | del state_dict[k] |
| 94 | |
| 95 | # delete relative_coords_table since we always re-init it |
| 96 | relative_position_index_keys = [k for k in state_dict.keys() if "relative_coords_table" in k] |
| 97 | for k in relative_position_index_keys: |
| 98 | del state_dict[k] |
| 99 | |
| 100 | # delete attn_mask since we always re-init it |
| 101 | attn_mask_keys = [k for k in state_dict.keys() if "attn_mask" in k] |
| 102 | for k in attn_mask_keys: |
| 103 | del state_dict[k] |
| 104 | |
| 105 | # bicubic interpolate relative_position_bias_table if not match |
| 106 | relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k] |
| 107 | for k in relative_position_bias_table_keys: |
| 108 | relative_position_bias_table_pretrained = state_dict[k] |
| 109 | relative_position_bias_table_current = model.state_dict()[k] |
| 110 | L1, nH1 = relative_position_bias_table_pretrained.size() |
| 111 | L2, nH2 = relative_position_bias_table_current.size() |
| 112 | if nH1 != nH2: |
| 113 | logger.warning(f"Error in loading {k}, passing......") |
| 114 | else: |
| 115 | if L1 != L2: |
| 116 | # bicubic interpolate relative_position_bias_table if not match |
| 117 | S1 = int(L1 ** 0.5) |
| 118 | S2 = int(L2 ** 0.5) |
| 119 | relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate( |
| 120 | relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2), |
| 121 | mode='bicubic') |
no test coverage detected