(valid_loader, model, epoch, cur_step, writer, logger, config)
| 280 | |
| 281 | |
| 282 | def test_sample(valid_loader, model, epoch, cur_step, writer, logger, config): |
| 283 | top1 = utils.AverageMeter() |
| 284 | top5 = utils.AverageMeter() |
| 285 | losses = utils.AverageMeter() |
| 286 | |
| 287 | model.eval() |
| 288 | device = torch.device("cuda") |
| 289 | criterion = nn.CrossEntropyLoss().to(device) |
| 290 | |
| 291 | |
| 292 | model.module.init_arch_params(layer_idx=0) |
| 293 | genotypes = [] |
| 294 | |
| 295 | for i in range(config.layer_num): |
| 296 | genotype, connect = model.module.generate_genotype(i) |
| 297 | genotypes.append(genotype) |
| 298 | |
| 299 | model.module.genotypes[i] = genotype |
| 300 | model.module.connects[i] = connect |
| 301 | |
| 302 | with torch.no_grad(): |
| 303 | for step, (X, y) in enumerate(valid_loader): |
| 304 | X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True) |
| 305 | N = X.size(0) |
| 306 | |
| 307 | # logits, _ = model(X, layer_idx=0, super_flag=True, pretrain_flag=True) |
| 308 | logits, _ = model(X, layer_idx=0, super_flag=True, pretrain_flag=True, is_slim=True) |
| 309 | loss = criterion(logits, y) |
| 310 | |
| 311 | prec1, prec5 = utils.accuracy(logits, y, topk=(1, 5)) |
| 312 | |
| 313 | if config.distributed: |
| 314 | reduced_loss = utils.reduce_tensor(loss.data, config.world_size) |
| 315 | prec1 = utils.reduce_tensor(prec1, config.world_size) |
| 316 | prec5 = utils.reduce_tensor(prec5, config.world_size) |
| 317 | else: |
| 318 | reduced_loss = loss.data |
| 319 | |
| 320 | losses.update(reduced_loss.item(), N) |
| 321 | top1.update(prec1.item(), N) |
| 322 | top5.update(prec5.item(), N) |
| 323 | |
| 324 | torch.cuda.synchronize() |
| 325 | step_num = len(valid_loader) |
| 326 | |
| 327 | if (step % config.print_freq == 0 or step == step_num-1) and config.local_rank == 0: |
| 328 | logger.info( |
| 329 | "Valid: Epoch {:2d}/{} Step {:03d}/{:03d} Loss {losses.avg:.3f} " |
| 330 | "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format( |
| 331 | epoch+1, config.epochs, step, step_num, |
| 332 | losses=losses, top1=top1, top5=top5)) |
| 333 | |
| 334 | if config.local_rank == 0: |
| 335 | writer.add_scalar('val/loss', losses.avg, cur_step) |
| 336 | writer.add_scalar('val/top1', top1.avg, cur_step) |
| 337 | writer.add_scalar('val/top5', top5.avg, cur_step) |
| 338 | |
| 339 | logger.info("Valid: Epoch {:2d}/{} Final Prec@1 {:.4%}".format( |
nothing calls this directly
no test coverage detected