| 18 | |
| 19 | |
| 20 | def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss, |
| 21 | data_loader: Iterable, optimizer: torch.optim.Optimizer, |
| 22 | device: torch.device, epoch: int, loss_scaler, amp_autocast, max_norm: float = 0, |
| 23 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, |
| 24 | set_training_mode=True, args = None): |
| 25 | model.train(set_training_mode) |
| 26 | metric_logger = utils.MetricLogger(delimiter=" ") |
| 27 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) |
| 28 | header = 'Epoch: [{}]'.format(epoch) |
| 29 | print_freq = 10 |
| 30 | |
| 31 | if args.cosub: |
| 32 | criterion = torch.nn.BCEWithLogitsLoss() |
| 33 | |
| 34 | # debug |
| 35 | # count = 0 |
| 36 | for samples, targets in metric_logger.log_every(data_loader, print_freq, header): |
| 37 | # count += 1 |
| 38 | # if count > 20: |
| 39 | # break |
| 40 | |
| 41 | samples = samples.to(device, non_blocking=True) |
| 42 | targets = targets.to(device, non_blocking=True) |
| 43 | |
| 44 | if mixup_fn is not None: |
| 45 | samples, targets = mixup_fn(samples, targets) |
| 46 | |
| 47 | if args.cosub: |
| 48 | samples = torch.cat((samples,samples),dim=0) |
| 49 | |
| 50 | if args.bce_loss: |
| 51 | targets = targets.gt(0.0).type(targets.dtype) |
| 52 | |
| 53 | with amp_autocast(): |
| 54 | outputs = model(samples, if_random_cls_token_position=args.if_random_cls_token_position, if_random_token_rank=args.if_random_token_rank) |
| 55 | # outputs = model(samples) |
| 56 | if not args.cosub: |
| 57 | loss = criterion(samples, outputs, targets) |
| 58 | else: |
| 59 | outputs = torch.split(outputs, outputs.shape[0]//2, dim=0) |
| 60 | loss = 0.25 * criterion(outputs[0], targets) |
| 61 | loss = loss + 0.25 * criterion(outputs[1], targets) |
| 62 | loss = loss + 0.25 * criterion(outputs[0], outputs[1].detach().sigmoid()) |
| 63 | loss = loss + 0.25 * criterion(outputs[1], outputs[0].detach().sigmoid()) |
| 64 | |
| 65 | if args.if_nan2num: |
| 66 | with amp_autocast(): |
| 67 | loss = torch.nan_to_num(loss) |
| 68 | |
| 69 | loss_value = loss.item() |
| 70 | |
| 71 | if not math.isfinite(loss_value): |
| 72 | print("Loss is {}, stopping training".format(loss_value)) |
| 73 | if args.if_continue_inf: |
| 74 | optimizer.zero_grad() |
| 75 | continue |
| 76 | else: |
| 77 | sys.exit(1) |