MCPcopy
hub / github.com/hustvl/Vim / init_weights

Method init_weights

det/detectron2/modeling/backbone/vim.py:77–117  ·  view source on GitHub ↗

Initialize the weights in backbone. Args: pretrained (str, optional): Path to pre-trained weights. Defaults to None.

(self, pretrained=None)

Source from the content-addressed store, hash-verified

75
76
def init_weights(self, pretrained=None):
    """Initialize the weights in backbone.

    Every ``nn.Linear`` and ``nn.LayerNorm`` submodule first receives a
    default initialization. When ``pretrained`` is a path, the ``"model"``
    entry of that checkpoint is then loaded on top with ``strict=False``.

    Args:
        pretrained (str, optional): Path to pre-trained weights.
            Defaults to None.

    Raises:
        TypeError: If ``pretrained`` is neither a ``str`` nor ``None``.
    """
    # Validate up front so an invalid argument never touches the module.
    if pretrained is not None and not isinstance(pretrained, str):
        raise TypeError('pretrained must be a str or None')

    def _init_weights(m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    # Default initialization runs in both modes; checkpoint tensors (if any)
    # overwrite it afterwards.
    self.apply(_init_weights)
    if pretrained is None:
        return

    logger = logging.getLogger(__name__)

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only pass checkpoints from a trusted source.
    state_dict = torch.load(pretrained, map_location="cpu")
    state_dict_model = state_dict["model"]

    # Drop the head and rope entries, which this method never loads.
    # A default is supplied so checkpoints that lack them do not raise
    # KeyError.
    for key in ("head.weight", "head.bias",
                "rope.freqs_cos", "rope.freqs_sin"):
        state_dict_model.pop(key, None)

    # load_state_dict(strict=False) tolerates missing keys but still raises
    # on shape mismatches, so a patch projection with a different kernel
    # size has to be removed before loading.
    proj_weight = state_dict_model.get("patch_embed.proj.weight")
    if (proj_weight is not None
            and self.patch_embed.patch_size[-1] != proj_weight.shape[-1]):
        state_dict_model.pop("patch_embed.proj.weight", None)
        state_dict_model.pop("patch_embed.proj.bias", None)

    # NOTE(review): assumed to run regardless of the patch-size check above
    # -- TODO confirm against the upstream file's indentation.
    interpolate_pos_embed(self, state_dict_model)

    # Report missing/unexpected keys through the logger only.
    res = self.load_state_dict(state_dict_model, strict=False)
    logger.info(res)
118
def get_num_layers(self):
    """Return the number of entries in ``self.layers``."""
    layer_count = len(self.layers)
    return layer_count

Callers 3

__init__Method · 0.95
__init__Method · 0.45
__init__Method · 0.45

Calls 4

interpolate_pos_embedFunction · 0.90
printFunction · 0.85
loadMethod · 0.45
load_state_dictMethod · 0.45

Tested by

no test coverage detected