| 15 | import utils |
| 16 | |
| 17 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, |
| 18 | data_loader: Iterable, optimizer: torch.optim.Optimizer, |
| 19 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, |
| 20 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None, |
| 21 | wandb_logger=None, start_steps=None, lr_schedule_values=None, wd_schedule_values=None, |
| 22 | num_training_steps_per_epoch=None, update_freq=None, use_amp=False): |
| 23 | model.train(True) |
| 24 | metric_logger = utils.MetricLogger(delimiter=" ") |
| 25 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) |
| 26 | metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) |
| 27 | header = 'Epoch: [{}]'.format(epoch) |
| 28 | print_freq = 200 |
| 29 | |
| 30 | optimizer.zero_grad() |
| 31 | |
| 32 | for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): |
| 33 | step = data_iter_step // update_freq |
| 34 | if step >= num_training_steps_per_epoch: |
| 35 | continue |
| 36 | it = start_steps + step # global training iteration |
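| 37 | # Update LR & WD on the first micro-batch of each gradient-accumulation cycle (NOTE: `and` binds tighter than `or`, so the modulo check below only applies when lr_schedule_values is None) |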
| 38 | if lr_schedule_values is not None or wd_schedule_values is not None and data_iter_step % update_freq == 0: |
| 39 | for i, param_group in enumerate(optimizer.param_groups): |
| 40 | if lr_schedule_values is not None: |
| 41 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"] |
| 42 | if wd_schedule_values is not None and param_group["weight_decay"] > 0: |
| 43 | param_group["weight_decay"] = wd_schedule_values[it] |
| 44 | |
| 45 | samples = samples.to(device, non_blocking=True) |
| 46 | targets = targets.to(device, non_blocking=True) |
| 47 | |
| 48 | if mixup_fn is not None: |
| 49 | samples, targets = mixup_fn(samples, targets) |
| 50 | |
| 51 | if use_amp: |
| 52 | with torch.cuda.amp.autocast(): |
| 53 | output = model(samples) |
| 54 | loss = criterion(output, targets) |
| 55 | else: # full precision |
| 56 | output = model(samples) |
| 57 | loss = criterion(output, targets) |
| 58 | |
| 59 | loss_value = loss.item() |
| 60 | |
| 61 | if not math.isfinite(loss_value): # this could trigger if using AMP |
| 62 | print("Loss is {}, stopping training".format(loss_value)) |
| 63 | assert math.isfinite(loss_value) |
| 64 | |
| 65 | if use_amp: |
| 66 | # this attribute is added by timm on one optimizer (adahessian) |
| 67 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order |
| 68 | loss /= update_freq |
| 69 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm, |
| 70 | parameters=model.parameters(), create_graph=is_second_order, |
| 71 | update_grad=(data_iter_step + 1) % update_freq == 0) |
| 72 | if (data_iter_step + 1) % update_freq == 0: |
| 73 | optimizer.zero_grad() |
| 74 | if model_ema is not None: |