(model, path)
| 150 | |
| 151 | |
def load_weights(model, path):
    """Load the weights stored at ``path`` into ``model``.

    The file is read with ``torch.load`` and mapped onto the CPU. If the
    loaded object contains a ``'model'`` entry, that entry replaces it; if
    the result then contains a ``'state_dict'`` entry, that entry replaces
    it in turn. Whatever remains is handed to ``load_state_dict`` of
    ``unwrap_model(model)`` with ``strict=False``, after which a
    confirmation line naming ``path`` is printed.

    Args:
        model: Object passed to ``unwrap_model``; the value returned by
            that call receives the weights.
        path: Location of the checkpoint, forwarded to ``torch.load``.
    """
    state = torch.load(path, map_location='cpu')

    # Peel off wrapper entries in a fixed order: 'model' first, then
    # 'state_dict' (checked against whatever the previous step produced).
    for wrapper_key in ('model', 'state_dict'):
        if wrapper_key in state:
            state = state[wrapper_key]

    target = unwrap_model(model)
    target.load_state_dict(state, strict=False)
    print('=================== loaded from', path)
| 160 | |
| 161 | def save_latest(config, epoch, model, max_accuracy, optimizer, lr_scheduler, logger, model_ema=None): |
| 162 | save_state = {'model': model.state_dict(), |