| 131 | |
| 132 | |
| 133 | def retrain_warmup(valid_loader, model, optimizer, epoch, writer, logger, super_flag, retrain_epochs, config): |
| 134 | |
| 135 | device = torch.device("cuda") |
| 136 | criterion = nn.CrossEntropyLoss().to(device) |
| 137 | top1 = utils.AverageMeter() |
| 138 | top5 = utils.AverageMeter() |
| 139 | losses = utils.AverageMeter() |
| 140 | |
| 141 | step_num = len(valid_loader) |
| 142 | step_num = int(step_num * config.sample_ratio) |
| 143 | |
| 144 | cur_step = epoch*step_num |
| 145 | cur_lr = optimizer.param_groups[0]['lr'] |
| 146 | if config.local_rank == 0: |
| 147 | logger.info("Warmup Epoch {} LR {:.3f}".format(epoch+1, cur_lr)) |
| 148 | writer.add_scalar('warmup/lr', cur_lr, cur_step) |
| 149 | |
| 150 | model.train() |
| 151 | |
| 152 | for step, (val_X, val_y) in enumerate(valid_loader): |
| 153 | if step > step_num: |
| 154 | break |
| 155 | |
| 156 | val_X, val_y = val_X.to(device, non_blocking=True), val_y.to(device, non_blocking=True) |
| 157 | N = val_X.size(0) |
| 158 | |
| 159 | optimizer.zero_grad() |
| 160 | logits_main, _ = model(val_X, super_flag=super_flag) |
| 161 | loss = criterion(logits_main, val_y) |
| 162 | loss.backward() |
| 163 | |
| 164 | nn.utils.clip_grad_norm_(model.module.parameters(), config.w_grad_clip) |
| 165 | optimizer.step() |
| 166 | |
| 167 | prec1, prec5 = utils.accuracy(logits_main, val_y, topk=(1, 5)) |
| 168 | if config.distributed: |
| 169 | reduced_loss = utils.reduce_tensor(loss.data, config.world_size) |
| 170 | prec1 = utils.reduce_tensor(prec1, config.world_size) |
| 171 | prec5 = utils.reduce_tensor(prec5, config.world_size) |
| 172 | |
| 173 | else: |
| 174 | reduced_loss = loss.data |
| 175 | |
| 176 | losses.update(reduced_loss.item(), N) |
| 177 | top1.update(prec1.item(), N) |
| 178 | top5.update(prec5.item(), N) |
| 179 | |
| 180 | torch.cuda.synchronize() |
| 181 | if config.local_rank == 0 and (step % config.print_freq == 0 or step == step_num): |
| 182 | logger.info( |
| 183 | "Warmup: Epoch {:2d}/{} Step {:03d}/{:03d} Loss {losses.avg:.3f} " |
| 184 | "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format( |
| 185 | epoch+1, retrain_epochs, step, |
| 186 | step_num, losses=losses, top1=top1, top5=top5)) |
| 187 | |
| 188 | if config.local_rank == 0: |
| 189 | writer.add_scalar('retrain/loss', reduced_loss.item(), cur_step) |
| 190 | writer.add_scalar('retrain/top1', prec1.item(), cur_step) |