github.com/jindongwang/transferlearning

Function torch_load

code/ASR/CMatch/utils.py, lines 65–77
torch_load(snapshot_path, model, optimizer=None)

Loads model weights, and optionally optimizer state, from a snapshot file.


import collections

import torch


# The save routine just above (utils.py lines 63–64) writes either a bare
# model.state_dict() or, when an optimizer is given,
# OrderedDict(model=model.state_dict(), optimizer=optimizer.state_dict()).
def torch_load(snapshot_path, model, optimizer=None):
    # Load the snapshot with every tensor kept on the CPU.
    snapshot_dict = torch.load(snapshot_path, map_location=lambda storage, loc: storage)
    # A bare model state_dict has no "model" key; wrap it so both
    # layouts are handled the same way below.
    if "model" not in snapshot_dict:
        model_dict = snapshot_dict
        snapshot_dict = collections.OrderedDict(model=model_dict)
    # Wrapped models (e.g. nn.DataParallel) expose the network as .module.
    if hasattr(model, "module"):
        model.module.load_state_dict(snapshot_dict["model"])
    else:
        model.load_state_dict(snapshot_dict["model"])
    # Raises KeyError if an optimizer is passed but the snapshot holds
    # only model weights.
    if optimizer:
        optimizer.load_state_dict(snapshot_dict["optimizer"])
    # Release the local reference (it would also go out of scope on return).
    del snapshot_dict
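Both snapshot layouts accepted by torch_load can be exercised end to end with a toy model. This is a minimal sketch, assuming a recent PyTorch; the nn.Linear model, the SGD optimizer, the temporary file names and the helper load_like_torch_load (a hypothetical condensed copy of the loading logic, without the DataParallel branch) are illustrative stand-ins, not part of the CMatch code. It uses map_location="cpu", which the torch.load documentation gives as equivalent to the storage lambda for loading everything onto the CPU.

```python
import collections
import os
import tempfile

import torch
import torch.nn as nn


def load_like_torch_load(snapshot_path, model, optimizer=None):
    """Hypothetical condensed copy of torch_load (no DataParallel branch)."""
    snapshot = torch.load(snapshot_path, map_location="cpu")
    if "model" not in snapshot:
        # Bare state_dict: wrap it under a "model" key.
        snapshot = collections.OrderedDict(model=snapshot)
    model.load_state_dict(snapshot["model"])
    if optimizer is not None:
        optimizer.load_state_dict(snapshot["optimizer"])


source = nn.Linear(4, 2)  # toy stand-in for the ASR model
source_opt = torch.optim.SGD(source.parameters(), lr=0.1)

tmpdir = tempfile.mkdtemp()
bare_path = os.path.join(tmpdir, "bare.pth")
full_path = os.path.join(tmpdir, "full.pth")

# Layout 1: a bare model state_dict (what the save routine writes without an optimizer).
torch.save(source.state_dict(), bare_path)
# Layout 2: model and optimizer state under "model" / "optimizer" keys.
torch.save(
    collections.OrderedDict(model=source.state_dict(), optimizer=source_opt.state_dict()),
    full_path,
)

# Restore each layout into a freshly initialised model.
from_bare = nn.Linear(4, 2)
load_like_torch_load(bare_path, from_bare)

from_full = nn.Linear(4, 2)
full_opt = torch.optim.SGD(from_full.parameters(), lr=0.5)
load_like_torch_load(full_path, from_full, full_opt)

print(torch.equal(from_bare.weight, source.weight))  # True
print(full_opt.param_groups[0]["lr"])  # 0.1, restored from the snapshot
```

Wrapping the bare layout under a "model" key lets one code path serve both file formats; the cost is that a caller passing an optimizer with a bare snapshot gets a KeyError rather than a clear message.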

Callers (3)

test (function)
train.py (file)
recognize_and_evaluate (function)

Calls (2)

torch.load
load_state_dict (on the model, and on the optimizer when one is passed)
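The hasattr(model, "module") branch decides which object receives load_state_dict. A wrapper such as nn.DataParallel registers the network as a submodule named "module", so every key in its state_dict gains a "module." prefix, and an unprefixed state_dict only matches the inner network. A minimal sketch, assuming PyTorch; the toy nn.Linear is a stand-in, and nn.DataParallel can be constructed without a GPU, in which case it simply holds the model in .module.

```python
import torch
import torch.nn as nn

plain = nn.Linear(4, 2)        # stand-in for an unwrapped model
weights = plain.state_dict()   # keys: "weight", "bias"

wrapped = nn.DataParallel(nn.Linear(4, 2))
print(sorted(wrapped.state_dict()))  # ['module.bias', 'module.weight']

# Loading unprefixed keys into the wrapper itself fails under strict matching.
try:
    wrapped.load_state_dict(weights)
    wrapper_load_ok = True
except RuntimeError:
    wrapper_load_ok = False

# What torch_load does instead: load into the underlying network.
if hasattr(wrapped, "module"):
    wrapped.module.load_state_dict(weights)

print(wrapper_load_ok)  # False
# .cpu() keeps the comparison valid if DataParallel moved the module to a GPU.
print(torch.equal(wrapped.module.weight.cpu(), plain.weight))  # True
```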

Tested by

no test coverage detected