Repository: github.com/jindongwang/transferlearning

Function load_pretrained_model

code/ASR/Adapter/utils.py, lines 66–85

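load_pretrained_model(model, model_path, modules_to_load=None, exclude_modules=None)

Loads the checkpoint at model_path onto the CPU and copies its entries into the state dict of model. exclude_modules and modules_to_load are comma-separated strings of key prefixes: keys starting with any excluded prefix are dropped first, then, if modules_to_load is given, only keys starting with one of its prefixes are kept. Entries that are not selected keep the model's current values. The docstring's example call passes exclude_modules="", which behaves the same as None because both are falsy.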
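The key-selection part of load_pretrained_model only looks at parameter names, so its behavior can be illustrated without PyTorch. The sketch below is a minimal, hypothetical helper (select_keys is not part of the repository) that copies the exclude-then-include logic onto a plain dict; the parameter names in the fake checkpoint are made up for illustration.

```python
# Hypothetical helper: reproduces only the key-selection logic of
# load_pretrained_model on a plain dict (no torch, no checkpoint file).


def select_keys(model_dict, modules_to_load=None, exclude_modules=None):
    # Step 1: drop every key that starts with any excluded prefix
    if exclude_modules:
        for e in exclude_modules.split(","):
            model_dict = {k: v for k, v in model_dict.items() if not k.startswith(e)}
    # Step 2: with no modules_to_load, everything that survived step 1 is kept
    if not modules_to_load:
        return dict(model_dict)
    # Otherwise keep only keys under one of the requested prefixes
    src_dict = {}
    for module in modules_to_load.split(","):
        src_dict.update({k: v for k, v in model_dict.items() if k.startswith(module)})
    return src_dict


# Made-up parameter names standing in for a checkpoint's state dict
ckpt = {
    "encoder.layer0.weight": 1,
    "encoder.adapter.weight": 2,
    "decoder.embed.weight": 3,
    "ctc.proj.weight": 4,
}

# Load the encoder but leave its adapter at the model's current values
print(sorted(select_keys(ckpt, "encoder", "encoder.adapter")))
# ['encoder.layer0.weight']

# Several prefixes are passed as one comma-separated string
print(sorted(select_keys(ckpt, "encoder,ctc")))
# ['ctc.proj.weight', 'encoder.adapter.weight', 'encoder.layer0.weight']
```

Because the arguments are split on "," without stripping whitespace, a value such as "encoder, ctc" looks for keys starting with " ctc" and silently selects none of the ctc weights.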


import torch


# Load and save
def load_pretrained_model(model, model_path, modules_to_load=None, exclude_modules=None):
    '''
    load_pretrained_model(model=model, model_path="",
                          modules_to_load=None, exclude_modules="")
    '''
    # Load the checkpoint onto the CPU, whatever device it was saved from
    model_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
    # Drop every entry whose key starts with one of the comma-separated excluded prefixes
    if exclude_modules:
        for e in exclude_modules.split(","):
            model_dict = {k: v for k, v in model_dict.items() if not k.startswith(e)}

    # Keep only the requested prefixes, or everything if none are given
    if not modules_to_load:
        src_dict = model_dict
    else:
        src_dict = {}
        for module in modules_to_load.split(","):
            src_dict.update({k: v for k, v in model_dict.items() if k.startswith(module)})

    # Overlay the selected entries on the model's current state and load the result
    dst_state = model.state_dict()
    dst_state.update(src_dict)
    model.load_state_dict(dst_state)


def torch_save(model, save_path, optimizer=None, local_rank=0):
    if local_rank != 0:
        return
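Since load_pretrained_model overlays the selected entries on model.state_dict(), two properties of its prefix test are worth knowing. str.startswith compares raw characters, so a prefix such as "encoder" also selects a sibling module like "encoder_proj" (a made-up name). And any selected key the model does not have is added to dst_state, which makes load_state_dict raise under its default strict=True. The sketch below is a hypothetical variant, not the repository's code: matches_prefix and merge_state are illustrative names, prefixes are matched on whole dotted components (with surrounding whitespace stripped), and keys unknown to the target are skipped. It works on plain dicts so it runs without PyTorch.

```python
# Hypothetical variant, not the repository's code: select and merge weights
# with component-aware prefix matching and without adding unknown keys.


def matches_prefix(key, prefix):
    # "encoder" matches "encoder" and "encoder.x", but not "encoder_proj.x"
    return key == prefix or key.startswith(prefix + ".")


def merge_state(dst_state, src_state, modules_to_load=None):
    # Same comma-separated convention as load_pretrained_model; blanks are ignored
    prefixes = [p.strip() for p in (modules_to_load or "").split(",") if p.strip()]
    merged = dict(dst_state)  # start from the target model's current values
    for k, v in src_state.items():
        if prefixes and not any(matches_prefix(k, p) for p in prefixes):
            continue  # not part of a requested module
        if k not in merged:
            continue  # would be an unexpected key for load_state_dict(strict=True)
        merged[k] = v
    return merged


dst = {"encoder.w": "init", "encoder_proj.w": "init", "decoder.w": "init"}
src = {"encoder.w": "ckpt", "encoder_proj.w": "ckpt", "decoder.w": "ckpt", "lm_head.w": "ckpt"}

# str.startswith, as used in load_pretrained_model, also catches encoder_proj
print(sorted(k for k in src if k.startswith("encoder")))
# ['encoder.w', 'encoder_proj.w']

print(merge_state(dst, src, "encoder"))
# {'encoder.w': 'ckpt', 'encoder_proj.w': 'init', 'decoder.w': 'init'}
```

With a real model, the merged dict would be passed to model.load_state_dict in the same way the original function passes dst_state.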

Callers (1)

train.py · File · 0.90

Calls (4)

load · Method · 0.45
update · Method · 0.45
state_dict · Method · 0.45
load_state_dict · Method · 0.45

Tested by

no test coverage detected