Initialize the weights of the backbone. Args: pretrained (str, optional): Path to pre-trained weights. Defaults to None.
(self, pretrained=None)
| 75 | |
| 76 | |
def init_weights(self, pretrained=None):
    """Initialize the weights of the backbone, optionally from a checkpoint.

    Every ``nn.Linear`` and ``nn.LayerNorm`` submodule is first re-initialized:
    linear layers get truncated-normal weights (std 0.02) and zero bias, and
    layer norms get unit weight and zero bias. When ``pretrained`` is a path,
    the checkpoint's ``"model"`` entry is then loaded on top with
    ``strict=False``.

    Args:
        pretrained (str, optional): Path to pre-trained weights.
            Defaults to None.

    Raises:
        TypeError: If ``pretrained`` is neither a str nor None.
        KeyError: If the checkpoint has no ``"model"`` entry or no
            ``"patch_embed.proj.weight"`` entry.
    """

    def _init_weights(m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # LayerNorm(elementwise_affine=False) has neither weight nor bias,
            # and LayerNorm(bias=False) has no bias, so guard each parameter.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    if pretrained is None:
        self.apply(_init_weights)
        return
    if not isinstance(pretrained, str):
        raise TypeError('pretrained must be a str or None')

    self.apply(_init_weights)
    logger = logging.getLogger(__name__)

    # NOTE(review): torch.load unpickles the file at ``pretrained``; only
    # point this at trusted checkpoints.
    state_dict = torch.load(pretrained, map_location="cpu")
    state_dict_model = state_dict["model"]

    # Discard the ``head.*`` parameters and the rope frequency tables so they
    # are not loaded. The ``None`` default tolerates checkpoints that were
    # saved without these entries instead of raising KeyError.
    for key in ("head.weight", "head.bias", "rope.freqs_cos", "rope.freqs_sin"):
        state_dict_model.pop(key, None)

    # If the checkpoint's patch-projection kernel width differs from this
    # model's patch size, the weights cannot be reused. Drop them so the
    # model keeps its own projection parameters.
    ckpt_patch = state_dict_model["patch_embed.proj.weight"].shape[-1]
    if self.patch_embed.patch_size[-1] != ckpt_patch:
        state_dict_model.pop("patch_embed.proj.weight")
        state_dict_model.pop("patch_embed.proj.bias", None)

    # Presumably resizes position embeddings in-place to this model's grid;
    # see ``interpolate_pos_embed`` for the exact contract.
    interpolate_pos_embed(self, state_dict_model)

    # strict=False: the popped keys (and any other mismatches) are reported
    # in ``res`` rather than raising.
    res = self.load_state_dict(state_dict_model, strict=False)
    logger.info(res)
def get_num_layers(self):
    """Return the number of entries held in ``self.layers``."""
    layer_stack = self.layers
    return len(layer_stack)
no test coverage detected