| 5 | from models.loss import CrossEntropyLabelSmooth |
| 6 | |
| 7 | def train(train_loader, model, optimizer, epoch, writer, logger, config): |
| 8 | device = torch.device("cuda") |
| 9 | if config.label_smooth > 0: |
| 10 | criterion = CrossEntropyLabelSmooth(config.n_classes, config.label_smooth).to(device) |
| 11 | else: |
| 12 | criterion = nn.CrossEntropyLoss().to(device) |
| 13 | |
| 14 | top1 = utils.AverageMeter() |
| 15 | top5 = utils.AverageMeter() |
| 16 | losses = utils.AverageMeter() |
| 17 | |
| 18 | step_num = len(train_loader) |
| 19 | cur_step = epoch*step_num |
| 20 | cur_lr = optimizer.param_groups[0]['lr'] |
| 21 | if config.local_rank == 0: |
| 22 | logger.info("Train Epoch {} LR {}".format(epoch, cur_lr)) |
| 23 | writer.add_scalar('train/lr', cur_lr, cur_step) |
| 24 | |
| 25 | model.train() |
| 26 | |
| 27 | for step, (X, y) in enumerate(train_loader): |
| 28 | X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True) |
| 29 | N = X.size(0) |
| 30 | |
| 31 | X, target_a, target_b, lam = data_utils.mixup_data(X, y, config.mixup_alpha, use_cuda=True) |
| 32 | |
| 33 | optimizer.zero_grad() |
| 34 | logits, logits_aux = model(X, layer_idx=0, super_flag=True, pretrain_flag=True) |
| 35 | loss = data_utils.mixup_criterion(criterion, logits, target_a, target_b, lam) |
| 36 | if config.aux_weight > 0: |
| 37 | # loss_aux = criterion(logits_aux, y) |
| 38 | loss_aux = data_utils.mixup_criterion(criterion, logits_aux, target_a, target_b, lam) |
| 39 | loss = loss + config.aux_weight * loss_aux |
| 40 | |
| 41 | if config.use_amp: |
| 42 | from apex import amp |
| 43 | with amp.scale_loss(loss, optimizer) as scaled_loss: |
| 44 | scaled_loss.backward() |
| 45 | else: |
| 46 | loss.backward() |
| 47 | # gradient clipping |
| 48 | nn.utils.clip_grad_norm_(model.module.parameters(), config.grad_clip) |
| 49 | optimizer.step() |
| 50 | |
| 51 | prec1, prec5 = utils.accuracy(logits, y, topk=(1, 5)) |
| 52 | if config.distributed: |
| 53 | reduced_loss = utils.reduce_tensor(loss.data, config.world_size) |
| 54 | prec1 = utils.reduce_tensor(prec1, config.world_size) |
| 55 | prec5 = utils.reduce_tensor(prec5, config.world_size) |
| 56 | else: |
| 57 | reduced_loss = loss.data |
| 58 | |
| 59 | losses.update(reduced_loss.item(), N) |
| 60 | top1.update(prec1.item(), N) |
| 61 | top5.update(prec5.item(), N) |
| 62 | |
| 63 | torch.cuda.synchronize() |
| 64 | if config.local_rank == 0 and (step % config.print_freq == 0 or step == step_num): |