Loads a pre-trained Latte checkpoint from a local file path (no downloading is performed), returning its EMA weights when present and otherwise its model weights.
(model_name)
| 272 | return video_grid |
| 273 | |
def find_model(model_name):
    """
    Load a pre-trained Latte checkpoint from a local file.

    No downloading is performed: ``model_name`` must be the path of an
    existing checkpoint file.

    Args:
        model_name: Filesystem path to a checkpoint saved with ``torch.save``.

    Returns:
        The ``"ema"`` entry of the loaded checkpoint when present (the layout
        written by train.py), otherwise its ``"model"`` entry.

    Raises:
        FileNotFoundError: If ``model_name`` is not an existing file.
        KeyError: If the checkpoint has neither an ``"ema"`` nor a ``"model"`` key.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not os.path.isfile(model_name):
        raise FileNotFoundError(f'Could not find Latte checkpoint at {model_name}')

    # Returning the storage unchanged deserializes every tensor onto the CPU,
    # regardless of the device it was saved from.
    # NOTE(review): torch.load unpickles arbitrary objects, so only load trusted
    # checkpoints. weights_only=True is not forced here because it may reject
    # checkpoints that also store non-tensor objects (e.g. training args) - confirm.
    checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)

    if "ema" in checkpoint:  # supports checkpoints from train.py
        print('Using Ema!')
        return checkpoint["ema"]
    print('Using model!')
    return checkpoint['model']
| 288 | |
| 289 | ################################################################################# |
| 290 | # MMCV Utils # |