(path, model, optimizer, reset_optimizer=False, overwrite_global_states=True)
| 373 | |
| 374 | |
def load_checkpoint(path, model, optimizer, reset_optimizer=False, overwrite_global_states=True):
    """Restore model (and optionally optimizer / training counters) from a checkpoint.

    Args:
        path: Checkpoint location, passed unchanged to ``_load``.
        model: Object whose ``load_state_dict`` receives the saved weights
            (taken from the checkpoint's ``"state_dict"`` entry).
        optimizer: Object whose ``load_state_dict`` receives the saved
            optimizer state, unless ``reset_optimizer`` is true or the
            checkpoint holds no optimizer state.
        reset_optimizer: When true, do not restore the optimizer state.
        overwrite_global_states: When true, overwrite the module-level
            ``global_step`` and ``global_epoch`` with the values stored in
            the checkpoint.

    Returns:
        The same ``model`` object that was passed in, after loading its state.

    Raises:
        KeyError: If the checkpoint lacks ``"state_dict"``, or lacks
            ``"global_step"`` / ``"global_epoch"`` while
            ``overwrite_global_states`` is true.
    """
    global global_step
    global global_epoch

    print("Load checkpoint from: {}".format(path))
    # NOTE(review): assumes _load returns a dict-like checkpoint — confirm.
    checkpoint = _load(path)

    prefix = 'module.'
    # Strip only a *leading* "module." (presumably added by a wrapper such as
    # torch DataParallel — confirm). str.replace would also rewrite inner
    # names, e.g. "encoder.submodule.weight" -> "encoder.subweight".
    new_s = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in checkpoint["state_dict"].items()
    }
    model.load_state_dict(new_s)

    if not reset_optimizer:
        # A checkpoint saved without optimizer state is treated the same as
        # one that explicitly stored None: the optimizer is left untouched.
        optimizer_state = checkpoint.get("optimizer")
        if optimizer_state is not None:
            print("Load optimizer state from {}".format(path))
            optimizer.load_state_dict(optimizer_state)

    if overwrite_global_states:
        global_step = checkpoint["global_step"]
        global_epoch = checkpoint["global_epoch"]

    return model
| 396 | |
| 397 | if __name__ == "__main__": |
| 398 | checkpoint_dir = args.checkpoint_dir |
no test coverage detected