(checkpoint_path)
| 364 | print("Saved checkpoint:", checkpoint_path) |
| 365 | |
def _load(checkpoint_path):
    """Deserialize the checkpoint stored at ``checkpoint_path`` via ``torch.load``.

    Reads the ``use_cuda`` flag, which is defined outside this function.
    When it is truthy, ``torch.load`` runs with its default ``map_location``.
    When it is falsy, an identity ``map_location`` callable is passed instead.

    Args:
        checkpoint_path: Path of the file to load.

    Returns:
        Whatever object ``torch.load`` deserializes from the file. The caller
        presumably expects a dict of saved state (TODO confirm against callers).
    """
    load_kwargs = {}
    if not use_cuda:
        # Identity mapping: each storage stays where deserialization placed it,
        # rather than being moved back to the device it was saved from.
        load_kwargs["map_location"] = lambda storage, loc: storage
    return torch.load(checkpoint_path, **load_kwargs)
| 373 | |
| 374 | |
| 375 | def load_checkpoint(path, model, optimizer, reset_optimizer=False, overwrite_global_states=True): |