(model, ckpt_path)
| 67 | return res |
| 68 | |
def load_checkpoint(model, ckpt_path, map_location=None, strict=True):
    """Load weights stored at ``ckpt_path`` into ``model``.

    The object returned by ``torch.load`` is unwrapped from a ``'model'``
    entry and then from a ``'state_dict'`` entry, when those keys are present.
    A leading ``'module.'`` on any parameter name is stripped before loading.
    That prefix is typically added when a model is wrapped in
    ``nn.DataParallel`` / ``DistributedDataParallel`` at save time.

    Args:
        model: module whose ``load_state_dict`` receives the weights.
        ckpt_path: path passed straight to ``torch.load``.
        map_location: forwarded to ``torch.load``. For example, ``'cpu'`` lets
            a checkpoint saved from GPU be loaded on a CPU-only machine. The
            default ``None`` keeps the previous behaviour: tensors are
            restored to the device they were saved from.
        strict: forwarded to ``model.load_state_dict``. Pass ``False`` to
            allow missing or unexpected keys.

    Returns:
        Whatever ``model.load_state_dict`` returns. For ``torch.nn.Module``
        this holds the missing and unexpected keys, which is useful when
        ``strict=False``.
    """
    checkpoint = torch.load(ckpt_path, map_location=map_location)
    # Unwrap common container layouts. The two checks run in sequence so that
    # a nested {'model': {'state_dict': ...}} layout is handled as well.
    if 'model' in checkpoint:
        checkpoint = checkpoint['model']
    if 'state_dict' in checkpoint:
        checkpoint = checkpoint['state_dict']
    prefix = 'module.'
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint.items()
    }
    return model.load_state_dict(state_dict, strict=strict)
| 82 | |
| 83 | |
| 84 | class WarmupCosineAnnealingLR(torch.optim.lr_scheduler._LRScheduler): |
no test coverage detected