(args, config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, loss_scaler)
| 193 | |
| 194 | |
| 195 | def train_one_epoch(args, config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, loss_scaler): |
| 196 | model.train() |
| 197 | set_bn_state(config, model) |
| 198 | optimizer.zero_grad() |
| 199 | |
| 200 | num_steps = len(data_loader) |
| 201 | batch_time = AverageMeter() |
| 202 | loss_meter = AverageMeter() |
| 203 | norm_meter = AverageMeter() |
| 204 | scaler_meter = AverageMeter() |
| 205 | acc1_meter = AverageMeter() |
| 206 | acc5_meter = AverageMeter() |
| 207 | |
| 208 | start = time.time() |
| 209 | end = time.time() |
| 210 | for idx, (samples, targets) in enumerate(data_loader): |
| 211 | normal_global_idx = epoch * NORM_ITER_LEN + \ |
| 212 | (idx * NORM_ITER_LEN // num_steps) |
| 213 | |
| 214 | samples = samples.cuda(non_blocking=True) |
| 215 | targets = targets.cuda(non_blocking=True) |
| 216 | |
| 217 | if mixup_fn is not None: |
| 218 | samples, targets = mixup_fn(samples, targets) |
| 219 | original_targets = targets.argmax(dim=1) |
| 220 | else: |
| 221 | original_targets = targets |
| 222 | |
| 223 | with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE): |
| 224 | outputs = model(samples) |
| 225 | |
| 226 | loss = criterion(outputs, targets) |
| 227 | loss = loss / config.TRAIN.ACCUMULATION_STEPS |
| 228 | |
| 229 | # this attribute is added by timm on one optimizer (adahessian) |
| 230 | is_second_order = hasattr( |
| 231 | optimizer, 'is_second_order') and optimizer.is_second_order |
| 232 | grad_norm = loss_scaler(loss, optimizer, clip_grad=config.TRAIN.CLIP_GRAD, |
| 233 | parameters=model.parameters(), create_graph=is_second_order, |
| 234 | update_grad=(idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0) |
| 235 | if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0: |
| 236 | optimizer.zero_grad() |
| 237 | lr_scheduler.step_update( |
| 238 | (epoch * num_steps + idx) // config.TRAIN.ACCUMULATION_STEPS) |
| 239 | loss_scale_value = loss_scaler.state_dict().get("scale", 1.0) |
| 240 | |
| 241 | with torch.no_grad(): |
| 242 | acc1, acc5 = accuracy(outputs, original_targets, topk=(1, 5)) |
| 243 | acc1_meter.update(acc1.item(), targets.size(0)) |
| 244 | acc5_meter.update(acc5.item(), targets.size(0)) |
| 245 | |
| 246 | torch.cuda.synchronize() |
| 247 | |
| 248 | loss_meter.update(loss.item(), targets.size(0)) |
| 249 | if is_valid_grad_norm(grad_norm): |
| 250 | norm_meter.update(grad_norm) |
| 251 | scaler_meter.update(loss_scale_value) |
| 252 | batch_time.update(time.time() - end) |
no test coverage detected