Remove the DistributedDataParallel wrapper if present.
Signature: `unwrap_model(model)`
| Line | Code |
|---:|---|
| 114 | amp = None |
| 115 | |
| 116 | def unwrap_model(model): |
| 117 | """Remove the DistributedDataParallel wrapper if present.""" |
| 118 | wrapped = isinstance(model, torch.nn.parallel.distributed.DistributedDataParallel) |
| 119 | return model.module if wrapped else model |
| 120 | |
| 121 | |
| 122 | def load_checkpoint(config, model, optimizer, lr_scheduler, logger, model_ema=None): |
no outgoing calls
no test coverage detected