Repository: github.com/jindongwang/transferlearning

Function torch_load

Defined in code/ASR/Adapter/utils.py, lines 94–106.
Signature: torch_load(snapshot_path, model, optimizer=None)


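Context: the two lines just before this function (92–93) write the snapshot with torch.save(state_dict, save_path). When no optimizer is given they save model.state_dict() on its own; otherwise they save an OrderedDict with "model" and "optimizer" entries. torch_load accepts either layout.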
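```python
import collections

import torch


def torch_load(snapshot_path, model, optimizer=None):
    """Restore model (and optionally optimizer) state from a snapshot file."""
    # Load the snapshot with every tensor mapped to CPU, whatever device it was saved from.
    snapshot_dict = torch.load(snapshot_path, map_location=lambda storage, loc: storage)
    # A snapshot holding only model.state_dict() has no "model" key; wrap it so
    # both layouts can be read the same way below.
    if "model" not in snapshot_dict:
        snapshot_dict = collections.OrderedDict(model=snapshot_dict)
    # Wrappers such as nn.DataParallel keep the real network in .module.
    if hasattr(model, "module"):
        model.module.load_state_dict(snapshot_dict["model"])
    else:
        model.load_state_dict(snapshot_dict["model"])
    # Raises KeyError if an optimizer is passed but the snapshot has no "optimizer" entry.
    if optimizer is not None:
        optimizer.load_state_dict(snapshot_dict["optimizer"])
    # Drop the reference to the loaded snapshot (it would also go out of scope on return).
    del snapshot_dict
```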
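To see the branching on its own, here is a torch-free sketch. apply_snapshot, FakeModule and FakeWrapper are hypothetical names introduced only for illustration: the stand-ins mimic just the load_state_dict method and the .module attribute that torch_load relies on, and the torch.load step is left out. The second call shows the case the function does not guard against: passing an optimizer together with a model-only snapshot raises KeyError.

```python
import collections


class FakeModule:
    """Hypothetical stand-in for an nn.Module or optimizer: records what it receives."""

    def __init__(self):
        self.loaded = None

    def load_state_dict(self, state):
        self.loaded = state


class FakeWrapper:
    """Hypothetical stand-in for a wrapper such as nn.DataParallel (exposes .module)."""

    def __init__(self, module):
        self.module = module


def apply_snapshot(snapshot_dict, model, optimizer=None):
    """Same branching as torch_load, applied to an already-loaded snapshot."""
    if "model" not in snapshot_dict:
        # Model-only layout: wrap it so both layouts are read the same way.
        snapshot_dict = collections.OrderedDict(model=snapshot_dict)
    # Unwrap .module if present, otherwise load into the object itself.
    target = model.module if hasattr(model, "module") else model
    target.load_state_dict(snapshot_dict["model"])
    if optimizer is not None:
        # KeyError here if the snapshot was saved without an optimizer.
        optimizer.load_state_dict(snapshot_dict["optimizer"])


# Model-only snapshot loaded through a wrapper: the inner module gets the dict.
inner = FakeModule()
apply_snapshot({"fc.weight": 1}, FakeWrapper(inner))
print(inner.loaded)  # {'fc.weight': 1}

# Model-only snapshot plus an optimizer: the missing entry surfaces as KeyError.
try:
    apply_snapshot({"fc.weight": 1}, FakeModule(), optimizer=FakeModule())
except KeyError as err:
    print("missing", err)  # missing 'optimizer'
```

Callers that may resume from model-only snapshots would need to either skip the optimizer argument or check for the "optimizer" key first.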

Callers (3)

test (function) · 0.90
train.py (file) · 0.90
recognize_and_evaluate (function) · 0.70

Calls (2)

load (method) · 0.45
load_state_dict (method) · 0.45

Tested by

no test coverage detected
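Since no test coverage is detected, a round-trip test is one way to pin the behaviour down. This is a sketch that assumes PyTorch is installed and uses pytest-style test functions; it copies torch_load so it runs on its own, saves to an in-memory io.BytesIO buffer instead of a file, and uses nn.Linear and SGD purely as small examples. The test names and the Wrapper class (a hypothetical stand-in for nn.DataParallel that only exposes .module) are not part of the repository.

```python
import collections
import io

import torch


def torch_load(snapshot_path, model, optimizer=None):
    # Copy of the indexed function so this sketch is self-contained.
    snapshot_dict = torch.load(snapshot_path, map_location=lambda storage, loc: storage)
    if "model" not in snapshot_dict:
        snapshot_dict = collections.OrderedDict(model=snapshot_dict)
    target = model.module if hasattr(model, "module") else model
    target.load_state_dict(snapshot_dict["model"])
    if optimizer is not None:
        optimizer.load_state_dict(snapshot_dict["optimizer"])


class Wrapper(torch.nn.Module):
    """Hypothetical wrapper that, like nn.DataParallel, stores the network in .module."""

    def __init__(self, inner):
        super().__init__()
        self.module = inner


def test_round_trip_with_optimizer():
    src = torch.nn.Linear(3, 2)
    src_opt = torch.optim.SGD(src.parameters(), lr=0.1, momentum=0.9)
    # One step so the optimizer has momentum buffers to save.
    src(torch.randn(4, 3)).sum().backward()
    src_opt.step()

    buf = io.BytesIO()
    torch.save(collections.OrderedDict(model=src.state_dict(), optimizer=src_opt.state_dict()), buf)
    buf.seek(0)

    dst = torch.nn.Linear(3, 2)
    dst_opt = torch.optim.SGD(dst.parameters(), lr=0.5, momentum=0.0)
    torch_load(buf, dst, dst_opt)

    for key, value in src.state_dict().items():
        assert torch.equal(value, dst.state_dict()[key])
    # Hyperparameters and momentum buffers come from the snapshot.
    assert dst_opt.param_groups[0]["lr"] == 0.1
    assert dst_opt.param_groups[0]["momentum"] == 0.9
    assert torch.equal(
        src_opt.state[src.weight]["momentum_buffer"],
        dst_opt.state[dst.weight]["momentum_buffer"],
    )


def test_plain_state_dict_into_wrapper():
    src = torch.nn.Linear(3, 2)
    buf = io.BytesIO()
    torch.save(src.state_dict(), buf)  # model-only layout, no "model" key
    buf.seek(0)

    dst = Wrapper(torch.nn.Linear(3, 2))
    torch_load(buf, dst)

    for key, value in src.state_dict().items():
        assert torch.equal(value, dst.module.state_dict()[key])
```

The wrapper case matters because the wrapper's own state_dict keys carry a "module." prefix, so loading into it directly would fail; torch_load avoids that by loading into .module.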