MCPcopy
hub / github.com/jindongwang/transferlearning / load_head_from_pretrained_model

Function load_head_from_pretrained_model

code/ASR/Adapter/utils.py:10–20  ·  view source on GitHub ↗
(model, model_path)

Source from the content-addressed store, hash-verified

8
9
def load_head_from_pretrained_model(
    model,
    model_path,
    *,
    patterns=("decoder.embed.", "ctc.", "decoder.output_layer."),
):
    """Copy the output-head parameters of a saved checkpoint into ``model``.

    Only checkpoint entries whose key contains one of ``patterns`` as a
    substring are copied. Every other entry of ``model.state_dict()`` keeps
    its current value. By default this selects the decoder embedding, the
    CTC layer, and the decoder output layer.

    Args:
        model: Module whose parameters are updated in place through
            ``model.load_state_dict``.
        model_path: Path passed to ``torch.load``. The loaded object is either
            a state dict or a dict holding one under the ``"model"`` key.
        patterns: Substrings that select which checkpoint keys are copied.
            A single string is treated as a one-element tuple. The default
            reproduces the original hard-coded head selection.

    Raises:
        RuntimeError: Propagated from ``model.load_state_dict``, which is
            called with its default ``strict=True``. This happens, for
            example, when a selected checkpoint key does not exist in
            ``model`` or a tensor shape does not match.
    """
    if isinstance(patterns, str):
        # Guard against iterating a bare string character by character.
        patterns = (patterns,)

    # NOTE(review): torch.load unpickles the file; only pass checkpoints
    # from trusted sources.
    # The identity map_location keeps every tensor on CPU, so a checkpoint
    # saved on GPU can be loaded on a machine without one.
    checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
    # Full training checkpoints wrap the weights under a "model" key.
    if "model" in checkpoint:
        checkpoint = checkpoint["model"]

    head_state = {
        key: value
        for key, value in checkpoint.items()
        if any(pattern in key for pattern in patterns)
    }

    dst_state = model.state_dict()
    dst_state.update(head_state)
    # Log in the model's parameter order. Keys that exist only in the
    # checkpoint were appended by update() and therefore come last.
    for key in dst_state:
        if key in head_state:
            logging.info("loading %s", key)
    model.load_state_dict(dst_state)
21
22def load_adapter_from_pretrained_model(model, model_path, src_adapter, tgt_adapter):
23 '''

Callers 1

train.py (File) · 0.90

Calls 4

load (Method) · 0.45
state_dict (Method) · 0.45
update (Method) · 0.45
load_state_dict (Method) · 0.45

Tested by

no test coverage detected