| 196 | epoch+1, retrain_epochs, top1.avg)) |
| 197 | |
def validate(valid_loader, model, epoch, cur_step, writer, logger, super_flag, config):
    """Evaluate ``model`` on ``valid_loader`` and return the mean top-1 precision.

    Runs one pass over the validation data under ``torch.no_grad()``. It keeps
    running averages (``utils.AverageMeter``, updated with the batch size ``N``)
    of the cross-entropy loss and of the top-1 / top-5 precision reported by
    ``utils.accuracy``. Progress messages, the ``val/*`` TensorBoard scalars and
    the final summary are emitted only when ``config.local_rank == 0``.

    Args:
        valid_loader: Iterable of ``(X, y)`` batches. ``len(valid_loader)`` is
            used as the step count in progress messages.
        model: Network called as ``model(X, super_flag=False)``. The first
            element of the returned pair is used as the logits. The model is
            switched to eval mode and left in it on return.
        epoch: Zero-based epoch index (printed as ``epoch + 1``).
        cur_step: Global step used for the ``val/*`` scalars.
        writer: Object with ``add_scalar(tag, value, step)`` (presumably a
            TensorBoard ``SummaryWriter``).
        logger: Logger used for progress and summary messages.
        super_flag: Currently unused; the model is always evaluated with
            ``super_flag=False``.
        config: Must provide ``print_freq``, ``local_rank``, ``search_iter`` and
            ``search_iter_epochs``.

    Returns:
        ``top1.avg``: the running top-1 average over the batches seen by this
        process. The ``.1%``/``.4%`` log formats suggest ``utils.accuracy``
        reports fractions in [0, 1]; confirm against ``utils``.
    """
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()
    losses = utils.AverageMeter()

    model.eval()
    # Send batches to wherever the model's parameters live instead of assuming
    # the current CUDA device. A parameter-less model keeps the old behaviour.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")
    criterion = nn.CrossEntropyLoss().to(device)

    # Loop-invariant values used by the progress messages.
    total_epochs = config.search_iter * config.search_iter_epochs
    step_num = len(valid_loader)
    is_main = config.local_rank == 0

    with torch.no_grad():
        for step, (X, y) in enumerate(valid_loader):
            X = X.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            N = X.size(0)

            # NOTE(review): the ``super_flag`` argument is ignored here;
            # validation always runs with ``super_flag=False``. Confirm intended.
            logits, _ = model(X, super_flag=False)
            loss = criterion(logits, y)
            prec1, prec5 = utils.accuracy(logits, y, topk=(1, 5))

            # NOTE(review): nothing is all-reduced across processes. If
            # valid_loader is sharded per rank, rank 0 only reports its shard.
            losses.update(loss.item(), N)
            top1.update(prec1.item(), N)
            top5.update(prec5.item(), N)

            # torch.cuda.synchronize is only meaningful (and only valid) on CUDA.
            if device.type == "cuda":
                torch.cuda.synchronize(device)

            if (step % config.print_freq == 0 or step == step_num - 1) and is_main:
                logger.info(
                    "Valid: Epoch {:2d}/{} Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                    "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                        epoch + 1, total_epochs, step, step_num,
                        losses=losses, top1=top1, top5=top5))

    if is_main:
        writer.add_scalar('val/loss', losses.avg, cur_step)
        writer.add_scalar('val/top1', top1.avg, cur_step)
        writer.add_scalar('val/top5', top5.avg, cur_step)
        # Gated like the per-step messages, so non-zero ranks do not each print
        # their own summary line.
        logger.info("Valid: Epoch {:2d}/{} Final Prec@1 {:.4%}".format(
            epoch + 1, total_epochs, top1.avg))

    return top1.avg