(snapshot_path, model, optimizer=None)
| 92 | state_dict = model.state_dict() if not optimizer else collections.OrderedDict(model=model.state_dict(), optimizer=optimizer.state_dict()) |
| 93 | torch.save(state_dict, save_path) |
def torch_load(snapshot_path, model, optimizer=None):
    """Restore model (and optionally optimizer) state from a snapshot file.

    Two snapshot layouts are accepted, mirroring what ``torch_save`` writes:

    * a bare model ``state_dict`` (no ``"model"`` key), or
    * a mapping with a ``"model"`` entry and, when an optimizer was saved,
      an ``"optimizer"`` entry.

    Args:
        snapshot_path: Path of the snapshot, passed straight to ``torch.load``.
        model: Module that receives the weights. If it has a ``module``
            attribute, the weights are loaded into ``model.module`` instead
            (presumably to unwrap a DataParallel/DDP container -- TODO confirm
            against callers).
        optimizer: Optional optimizer. When truthy, its state is restored from
            the snapshot's ``"optimizer"`` entry.

    Raises:
        KeyError: If ``optimizer`` is given but the snapshot carries no
            optimizer state (for example, it is a bare model ``state_dict``).
            The model weights have already been loaded at that point.
    """
    # load snapshot
    # Returning each storage untouched keeps tensors on the CPU, so snapshots
    # written on GPU can be read on CPU-only machines.
    # NOTE(security): torch.load unpickles the file -- only load trusted snapshots.
    snapshot = torch.load(snapshot_path, map_location=lambda storage, loc: storage)

    if "model" in snapshot:
        model_state = snapshot["model"]
        optimizer_state = snapshot.get("optimizer")
    else:
        # A bare state_dict holds only the model parameters.
        model_state = snapshot
        optimizer_state = None

    target = model.module if hasattr(model, "module") else model
    target.load_state_dict(model_state)

    # Truthiness check mirrors torch_save's `if not optimizer` convention.
    if optimizer:
        if optimizer_state is None:
            raise KeyError(
                "snapshot %r contains no 'optimizer' state; it was probably "
                "saved without an optimizer" % (snapshot_path,)
            )
        optimizer.load_state_dict(optimizer_state)
| 107 | |
| 108 | # Decoding |
| 109 | def compute_wer(ref, hyp, normalize=False): |
no test coverage detected