| 19 | m.eval() |
| 20 | |
def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler,
                    clip_grad: float = 0,
                    clip_mode: str = 'norm',
                    model_ema: Optional[ModelEma] = None,
                    mixup_fn: Optional[Mixup] = None,
                    set_training_mode=True,
                    set_bn_eval=False,):
    """Run one pass over ``data_loader`` and return the averaged metrics.

    Args:
        model: Network to train; switched with ``model.train(set_training_mode)``.
        criterion: Loss callable invoked as ``criterion(samples, outputs, targets)``.
        data_loader: Iterable yielding ``(samples, targets)`` batches.
        optimizer: Optimizer whose gradients are zeroed every step and whose
            first param group's ``lr`` is logged.
        device: Device the batches are moved to.
        epoch: Epoch index, used only for the log header.
        loss_scaler: Callable invoked as ``loss_scaler(loss, optimizer,
            clip_grad=..., clip_mode=..., parameters=..., create_graph=...)``.
            NOTE(review): presumably performs the backward pass and the
            optimizer step, since neither is called here -- confirm.
        clip_grad: Forwarded to ``loss_scaler``.
        clip_mode: Forwarded to ``loss_scaler``.
        model_ema: Optional EMA wrapper; ``update(model)`` is called each step.
        mixup_fn: Optional callable applied to ``(samples, targets)`` after
            they are moved to ``device``.
        set_training_mode: Flag passed to ``model.train``.
        set_bn_eval: When true, ``set_bn_state(model)`` is called after the
            train-mode switch.

    Returns:
        Dict mapping each meter name to its ``global_avg``.

    Exits the process with status 1 if the loss becomes non-finite.
    """
    model.train(set_training_mode)
    if set_bn_eval:
        set_bn_state(model)

    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter(
        'lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 100

    # this attribute is added by timm on one optimizer (adahessian);
    # it does not change between steps, so look it up once.
    is_second_order = hasattr(
        optimizer, 'is_second_order') and optimizer.is_second_order

    # Only synchronize when CUDA is usable: an unconditional
    # torch.cuda.synchronize() raises on CPU-only hosts.
    cuda_available = torch.cuda.is_available()

    for samples, targets in metric_logger.log_every(
            data_loader, print_freq, header):
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        # Mixed precision is intentionally disabled here; wrap these two
        # lines in torch.cuda.amp.autocast() to re-enable it.
        outputs = model(samples)
        loss = criterion(samples, outputs, targets)

        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        optimizer.zero_grad()

        loss_scaler(loss, optimizer, clip_grad=clip_grad, clip_mode=clip_mode,
                    parameters=model.parameters(),
                    create_graph=is_second_order)

        if cuda_available:
            torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)

        metric_logger.update(loss=loss_value)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
| 74 | |
| 75 | |
| 76 | @torch.no_grad() |