MCPcopy
hub / github.com/microsoft/Cream / test_sample

Function test_sample

CDARTS/lib/core/pretrain_function.py:282–342  ·  view source on GitHub ↗
(valid_loader, model, epoch, cur_step, writer, logger, config)

Source from the content-addressed store, hash-verified

280
281
282def test_sample(valid_loader, model, epoch, cur_step, writer, logger, config):
283 top1 = utils.AverageMeter()
284 top5 = utils.AverageMeter()
285 losses = utils.AverageMeter()
286
287 model.eval()
288 device = torch.device("cuda")
289 criterion = nn.CrossEntropyLoss().to(device)
290
291
292 model.module.init_arch_params(layer_idx=0)
293 genotypes = []
294
295 for i in range(config.layer_num):
296 genotype, connect = model.module.generate_genotype(i)
297 genotypes.append(genotype)
298
299 model.module.genotypes[i] = genotype
300 model.module.connects[i] = connect
301
302 with torch.no_grad():
303 for step, (X, y) in enumerate(valid_loader):
304 X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)
305 N = X.size(0)
306
307 # logits, _ = model(X, layer_idx=0, super_flag=True, pretrain_flag=True)
308 logits, _ = model(X, layer_idx=0, super_flag=True, pretrain_flag=True, is_slim=True)
309 loss = criterion(logits, y)
310
311 prec1, prec5 = utils.accuracy(logits, y, topk=(1, 5))
312
313 if config.distributed:
314 reduced_loss = utils.reduce_tensor(loss.data, config.world_size)
315 prec1 = utils.reduce_tensor(prec1, config.world_size)
316 prec5 = utils.reduce_tensor(prec5, config.world_size)
317 else:
318 reduced_loss = loss.data
319
320 losses.update(reduced_loss.item(), N)
321 top1.update(prec1.item(), N)
322 top5.update(prec5.item(), N)
323
324 torch.cuda.synchronize()
325 step_num = len(valid_loader)
326
327 if (step % config.print_freq == 0 or step == step_num-1) and config.local_rank == 0:
328 logger.info(
329 "Valid: Epoch {:2d}/{} Step {:03d}/{:03d} Loss {losses.avg:.3f} "
330 "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
331 epoch+1, config.epochs, step, step_num,
332 losses=losses, top1=top1, top5=top5))
333
334 if config.local_rank == 0:
335 writer.add_scalar('val/loss', losses.avg, cur_step)
336 writer.add_scalar('val/top1', top1.avg, cur_step)
337 writer.add_scalar('val/top5', top5.avg, cur_step)
338
339 logger.info("Valid: Epoch {:2d}/{} Final Prec@1 {:.4%}".format(

Callers

nothing calls this directly

Calls 6

update (method) · 0.95
to (method) · 0.80
init_arch_params (method) · 0.80
generate_genotype (method) · 0.80
format (method) · 0.80
size (method) · 0.45

Tested by

no test coverage detected