(args, config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, loss_scaler)
| 282 | |
| 283 | |
| 284 | def train_one_epoch_distill_using_saved_logits(args, config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, loss_scaler): |
| 285 | model.train() |
| 286 | set_bn_state(config, model) |
| 287 | optimizer.zero_grad() |
| 288 | |
| 289 | num_steps = len(data_loader) |
| 290 | batch_time = AverageMeter() |
| 291 | loss_meter = AverageMeter() |
| 292 | norm_meter = AverageMeter() |
| 293 | scaler_meter = AverageMeter() |
| 294 | meters = defaultdict(AverageMeter) |
| 295 | |
| 296 | start = time.time() |
| 297 | end = time.time() |
| 298 | data_tic = time.time() |
| 299 | |
| 300 | num_classes = config.MODEL.NUM_CLASSES |
| 301 | topk = config.DISTILL.LOGITS_TOPK |
| 302 | |
| 303 | for idx, ((samples, targets), (logits_index, logits_value, seeds)) in enumerate(data_loader): |
| 304 | normal_global_idx = epoch * NORM_ITER_LEN + \ |
| 305 | (idx * NORM_ITER_LEN // num_steps) |
| 306 | |
| 307 | samples = samples.cuda(non_blocking=True) |
| 308 | targets = targets.cuda(non_blocking=True) |
| 309 | |
| 310 | if mixup_fn is not None: |
| 311 | samples, targets = mixup_fn(samples, targets, seeds) |
| 312 | original_targets = targets.argmax(dim=1) |
| 313 | else: |
| 314 | original_targets = targets |
| 315 | meters['data_time'].update(time.time() - data_tic) |
| 316 | |
| 317 | with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE): |
| 318 | outputs = model(samples) |
| 319 | |
| 320 | # recover teacher logits |
| 321 | logits_index = logits_index.long() |
| 322 | logits_value = logits_value.float() |
| 323 | logits_index = logits_index.cuda(non_blocking=True) |
| 324 | logits_value = logits_value.cuda(non_blocking=True) |
| 325 | minor_value = (1.0 - logits_value.sum(-1, keepdim=True) |
| 326 | ) / (num_classes - topk) |
| 327 | minor_value = minor_value.repeat_interleave(num_classes, dim=-1) |
| 328 | outputs_teacher = minor_value.scatter_(-1, logits_index, logits_value) |
| 329 | |
| 330 | loss = criterion(outputs, outputs_teacher) |
| 331 | loss = loss / config.TRAIN.ACCUMULATION_STEPS |
| 332 | |
| 333 | # this attribute is added by timm on one optimizer (adahessian) |
| 334 | is_second_order = hasattr( |
| 335 | optimizer, 'is_second_order') and optimizer.is_second_order |
| 336 | grad_norm = loss_scaler(loss, optimizer, clip_grad=config.TRAIN.CLIP_GRAD, |
| 337 | parameters=model.parameters(), create_graph=is_second_order, |
| 338 | update_grad=(idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0) |
| 339 | if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0: |
| 340 | optimizer.zero_grad() |
| 341 | lr_scheduler.step_update( |
no test coverage detected