Load checkpoint.
(args, depth_model, shift_model, focal_model)
| 25 | raise |
| 26 | |
def load_ckpt(args, depth_model, shift_model, focal_model):
    """
    Load model weights from the checkpoint file named by ``args.load_ckpt``.

    The checkpoint is read with ``torch.load`` and indexed by the keys
    ``'depth_model'``, ``'shift_model'`` and ``'focal_model'``. Each entry is
    passed through ``strip_prefix_if_present(..., 'module.')`` and then
    loaded with ``strict=True``.

    Args:
        args: Object with a ``load_ckpt`` attribute holding the checkpoint
            path.
        depth_model: Object exposing ``load_state_dict``; always loaded when
            the checkpoint file exists.
        shift_model: Object exposing ``load_state_dict``, or ``None`` to skip
            loading the ``'shift_model'`` entry.
        focal_model: Object exposing ``load_state_dict``, or ``None`` to skip
            loading the ``'focal_model'`` entry.

    Returns:
        None. If ``args.load_ckpt`` is not an existing file, a message is
        printed and the models are left untouched.

    Raises:
        KeyError: If a required entry is missing from the checkpoint.
    """
    ckpt_path = args.load_ckpt
    if not os.path.isfile(ckpt_path):
        # Keep the original non-fatal behaviour, but make the skip visible so
        # that running with unloaded weights does not go unnoticed.
        print("checkpoint %s not found, skipping load" % ckpt_path)
        return

    print("loading checkpoint %s" % ckpt_path)
    # Deserialise onto the CPU: load_state_dict copies values into each
    # model's existing parameters, so this works on CPU-only hosts and avoids
    # holding a second copy of the weights on the GPU.
    checkpoint = torch.load(ckpt_path, map_location="cpu")

    if shift_model is not None:
        shift_model.load_state_dict(
            strip_prefix_if_present(checkpoint['shift_model'], 'module.'),
            strict=True)
    if focal_model is not None:
        focal_model.load_state_dict(
            strip_prefix_if_present(checkpoint['focal_model'], 'module.'),
            strict=True)
    depth_model.load_state_dict(
        strip_prefix_if_present(checkpoint['depth_model'], "module."),
        strict=True)

    # Drop the reference to the loaded checkpoint before asking the CUDA
    # caching allocator to release its unused blocks.
    del checkpoint
    torch.cuda.empty_cache()
| 44 | |
| 45 | |
| 46 | def strip_prefix_if_present(state_dict, prefix): |
nothing calls this directly
no test coverage detected