(snapshot_path, model, optimizer=None)
| 63 | state_dict = model.state_dict() if not optimizer else collections.OrderedDict(model=model.state_dict(), optimizer=optimizer.state_dict()) |
| 64 | torch.save(state_dict, save_path) |
def torch_load(snapshot_path, model, optimizer=None):
    """Restore model (and optionally optimizer) state from a snapshot file.

    The snapshot may be either a bare model ``state_dict`` or a mapping with a
    ``"model"`` entry and, optionally, an ``"optimizer"`` entry.

    Args:
        snapshot_path: Path of the file handed to ``torch.load``.
        model: Module whose weights are restored. If it has a ``module``
            attribute (e.g. a ``DataParallel``-style wrapper), the weights
            are loaded into ``model.module`` instead.
        optimizer: Optional optimizer whose state is restored from the
            snapshot's ``"optimizer"`` entry.

    Raises:
        KeyError: If ``optimizer`` is given but the snapshot contains no
            ``"optimizer"`` entry.
    """
    # The identity map_location keeps every storage on CPU, so snapshots
    # written from GPU can be loaded on any machine.
    snapshot_dict = torch.load(snapshot_path, map_location=lambda storage, loc: storage)
    if "model" not in snapshot_dict:
        # Bare state_dict: wrap it so both layouts are handled uniformly below.
        snapshot_dict = collections.OrderedDict(model=snapshot_dict)

    target = model.module if hasattr(model, "module") else model
    target.load_state_dict(snapshot_dict["model"])

    if optimizer is not None:
        if "optimizer" not in snapshot_dict:
            # Previously this surfaced as a bare KeyError('optimizer'); keep the
            # exception type but say what is actually wrong.
            raise KeyError(
                "snapshot %r has no 'optimizer' entry to restore the optimizer from"
                % (snapshot_path,)
            )
        optimizer.load_state_dict(snapshot_dict["optimizer"])
| 78 | |
| 79 | # Decoding |
| 80 | def compute_wer(ref, hyp, normalize=False): |
no test coverage detected