hub / github.com/jindongwang/transferlearning / load_pretrained_model

Function load_pretrained_model

code/ASR/CMatch/utils.py:37–56

load_pretrained_model(model=model, model_path="", modules_to_load=None, exclude_modules="")

(model, model_path, modules_to_load=None, exclude_modules=None)

Source from the content-addressed store, hash-verified


# Load and save
def load_pretrained_model(model, model_path, modules_to_load=None, exclude_modules=None):
    '''
    load_pretrained_model(model=model, model_path="",
                          modules_to_load=None, exclude_modules="")
    '''
    model_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
    if exclude_modules:
        for e in exclude_modules.split(","):
            model_dict = {k: v for k, v in model_dict.items() if not k.startswith(e)}

    if not modules_to_load:
        src_dict = model_dict
    else:
        src_dict = {}
        for module in modules_to_load.split(","):
            src_dict.update({k: v for k, v in model_dict.items() if k.startswith(module)})

    dst_state = model.state_dict()
    dst_state.update(src_dict)
    model.load_state_dict(dst_state)
def torch_save(model, save_path, optimizer=None, local_rank=0):
    if local_rank != 0:
        return
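
Example (illustrative, not part of the repository)

The prefix filtering in load_pretrained_model can be tried without PyTorch or a checkpoint file. The sketch below copies only the key-selection logic into a hypothetical helper, select_state_keys, and runs it on a plain dict. The helper name and the checkpoint keys are assumptions made up for illustration; they are not taken from the repository.

```python
def select_state_keys(model_dict, modules_to_load=None, exclude_modules=None):
    """Hypothetical helper mirroring the prefix filtering in load_pretrained_model."""
    # Step 1: drop every key that starts with one of the excluded prefixes.
    if exclude_modules:
        for prefix in exclude_modules.split(","):
            model_dict = {k: v for k, v in model_dict.items() if not k.startswith(prefix)}

    # Step 2: with no include list, keep everything that survived step 1.
    if not modules_to_load:
        return dict(model_dict)

    # Step 3: otherwise keep only keys that start with one of the listed prefixes.
    selected = {}
    for prefix in modules_to_load.split(","):
        selected.update({k: v for k, v in model_dict.items() if k.startswith(prefix)})
    return selected


# Made-up parameter names standing in for a real checkpoint's state dict.
checkpoint = {
    "encoder.embed.weight": 1,
    "encoder.layer0.weight": 2,
    "decoder.layer0.weight": 3,
    "ctc.ctc_lo.weight": 4,
}

kept = select_state_keys(
    checkpoint,
    modules_to_load="encoder,ctc",
    exclude_modules="encoder.embed",
)
print(sorted(kept))  # ['ctc.ctc_lo.weight', 'encoder.layer0.weight']
```

Two details follow from the listing itself. Exclusion runs before inclusion, so an excluded prefix wins even when a broader include prefix would match it. The comma-separated lists are split without stripping whitespace, so "encoder, ctc" looks for the prefix " ctc" (with a leading space) and silently matches nothing for that entry.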

Callers 1

train.py · File · 0.90

Calls 4

load · Method · 0.45
update · Method · 0.45
state_dict · Method · 0.45
load_state_dict · Method · 0.45

Tested by

no test coverage detected