Repository: github.com/microsoft/Cream

Function validate

Cream/lib/core/test.py, lines 13–87
(epoch, model, loader, loss_fn, cfg, log_suffix='', logger=None, writer=None, local_rank=0)
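`validate` reads only three attributes from `cfg`. `TTA` is the number of test-time-augmentation copies per image, and values above 1 trigger the averaging step. `NUM_GPU` causes loss and precision to be all-reduced when it is above 1. `LOG_INTERVAL` makes rank 0 log every N batches plus the last one. The sketch below is a hypothetical stand-in built with `types.SimpleNamespace`, useful for calling `validate` outside the project's own config system. The field values are illustrative, not Cream defaults, and `logged_batches` is a hypothetical helper that reproduces the logging condition.

```python
from types import SimpleNamespace

# Hypothetical config stub: only the attributes validate() reads.
# Values are illustrative, not the project's defaults.
cfg = SimpleNamespace(TTA=1, NUM_GPU=1, LOG_INTERVAL=50)


def logged_batches(num_batches, log_interval, local_rank=0):
    """Batch indices for which validate() would emit a log line."""
    last_idx = num_batches - 1
    return [
        batch_idx
        for batch_idx in range(num_batches)
        # Same predicate as validate(): rank 0 only, every N batches plus the last one.
        if local_rank == 0 and (batch_idx == last_idx or batch_idx % log_interval == 0)
    ]


print(logged_batches(120, cfg.LOG_INTERVAL))  # [0, 50, 100, 119]
```

Because the log line formats `batch_idx` against `last_idx`, the counter runs from 0 to `last_idx`. The final line of a 120-batch run therefore reads `[ 119/119]`, not `[120/120]`.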


```python
# validate function
def validate(epoch, model, loader, loss_fn, cfg, log_suffix='', logger=None, writer=None, local_rank=0):
    batch_time_m = AverageMeter()
    losses_m = AverageMeter()
    prec1_m = AverageMeter()
    prec5_m = AverageMeter()

    model.eval()

    end = time.time()
    last_idx = len(loader) - 1
    with torch.no_grad():
        for batch_idx, (input, target) in enumerate(loader):
            last_batch = batch_idx == last_idx

            output = model(input)
            if isinstance(output, (tuple, list)):
                output = output[0]

            # augmentation reduction
            reduce_factor = cfg.TTA
            if reduce_factor > 1:
                output = output.unfold(
                    0,
                    reduce_factor,
                    reduce_factor).mean(
                    dim=2)
                target = target[0:target.size(0):reduce_factor]

            loss = loss_fn(output, target)
            prec1, prec5 = accuracy(output, target, topk=(1, 5))

            if cfg.NUM_GPU > 1:
                reduced_loss = reduce_tensor(loss.data, cfg.NUM_GPU)
                prec1 = reduce_tensor(prec1, cfg.NUM_GPU)
                prec5 = reduce_tensor(prec5, cfg.NUM_GPU)
            else:
                reduced_loss = loss.data

            torch.cuda.synchronize()

            losses_m.update(reduced_loss.item(), input.size(0))
            prec1_m.update(prec1.item(), output.size(0))
            prec5_m.update(prec5.item(), output.size(0))

            batch_time_m.update(time.time() - end)
            end = time.time()
            if local_rank == 0 and (last_batch or batch_idx % cfg.LOG_INTERVAL == 0):
                log_name = 'Test' + log_suffix
                logger.info(
                    '{0}: [{1:>4d}/{2}] '
                    'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                    'Prec@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
                    'Prec@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
                        log_name, batch_idx, last_idx,
                        batch_time=batch_time_m, loss=losses_m,
                        top1=prec1_m, top5=prec5_m))
```
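When `cfg.TTA = k > 1`, the augmentation-reduction step relies on the loader emitting the `k` augmented copies of each image as consecutive rows of the batch. `unfold(0, k, k)` cuts dimension 0 into non-overlapping windows of `k` rows, `mean(dim=2)` averages each window, and `target[0::k]` keeps one label per window. The standalone sketch below reproduces that step on a tiny tensor. The function name `reduce_tta` is hypothetical, and the body mirrors the two lines in `validate`.

```python
import torch


def reduce_tta(output, target, reduce_factor):
    """Average logits over groups of `reduce_factor` consecutive rows.

    Assumes, as validate() does, that rows [i*k, (i+1)*k) are the TTA
    copies of the same image.
    """
    if reduce_factor > 1:
        # (N, C) -> (N // k, C, k): non-overlapping windows along the batch dim
        output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
        # keep one label per group of copies
        target = target[0:target.size(0):reduce_factor]
    return output, target


logits = torch.tensor([[1.0, 0.0], [3.0, 2.0], [0.0, 4.0], [2.0, 0.0]])
labels = torch.tensor([0, 0, 1, 1])
out, tgt = reduce_tta(logits, labels, 2)
# out averages rows (0, 1) and (2, 3); tgt keeps labels[0] and labels[2]
```

The batch size must be a multiple of `k`. Otherwise `unfold` silently drops the trailing rows while the strided slice of `target` keeps one extra label, and `loss_fn` receives mismatched shapes. In `validate`, the loss meter is weighted by `input.size(0)` (the un-reduced count) and the precision meters by `output.size(0)` (the reduced count). With a constant `k` this scales every weight equally, so the running averages are unaffected.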

Callers (2)

main (function) · 0.90
main (function) · 0.90

Calls (7)

update (method) · 0.95
AverageMeter (class) · 0.90
accuracy (function) · 0.90
reduce_tensor (function) · 0.90
loss_fn (function) · 0.85
format (method) · 0.80
size (method) · 0.45
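The definitions of `AverageMeter`, `accuracy` and `reduce_tensor` are not shown on this page. The sketch below is a hypothetical reconstruction that is consistent with how `validate` uses them. The meters expose `.val` and `.avg` and take `update(value, n)`. `accuracy(output, target, topk)` returns one 0-dim tensor per `k`, assumed here to be a percentage as in the common timm-style helper. `reduce_tensor(t, n)` averages a tensor across `n` processes. It is not the project's actual code.

```python
import torch
import torch.distributed as dist


class AverageMeter:
    """Latest value plus a sample-weighted running average (hypothetical sketch)."""

    def __init__(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # `val` is a per-batch mean, so weight it by the batch size `n`
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Top-k accuracy in percent for each k in `topk` (hypothetical sketch)."""
    maxk = max(topk)
    batch_size = target.size(0)
    # indices of the maxk largest logits per sample, transposed to (maxk, batch)
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))
    return [correct[:k].reshape(-1).float().sum(0) * 100.0 / batch_size for k in topk]


def reduce_tensor(tensor, n):
    """Mean of `tensor` over `n` processes; requires an initialized process group."""
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= n
    return rt


logits = torch.tensor([[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]])
labels = torch.tensor([1, 1])
top1, top2 = accuracy(logits, labels, topk=(1, 2))  # 50.0 and 100.0

meter = AverageMeter()
meter.update(50.0, 2)
meter.update(100.0, 6)  # meter.avg == 87.5, meter.val == 100.0
```

Because `validate` calls `accuracy(..., topk=(1, 5))`, the model must output at least 5 classes, or `topk` raises. When `NUM_GPU > 1`, each rank's batch metric is averaged before `update`, and the meter still weights it by the local batch size.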

Tested by

no test coverage detected
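No tests are detected for `validate`. A first CPU-only smoke test would run into two details of the code as listed. `torch.cuda.synchronize()` is called on every batch regardless of device, which raises on a host or build without CUDA. `logger.info` is called on rank 0 even though `logger` defaults to `None`. The two hypothetical helpers below show one way a test harness or a patched copy could guard those steps. They are a suggestion, not part of the repository.

```python
import logging

import torch


def sync_if_cuda():
    """Wait for queued CUDA kernels only when a CUDA device is usable."""
    # torch.cuda.synchronize() raises without CUDA, so guard it for CPU runs
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def resolve_logger(logger=None):
    """Return the given logger, or a module-level fallback when it is None."""
    # validate() defaults logger=None but calls logger.info() on rank 0
    return logger if logger is not None else logging.getLogger("validate")
```

With these guards, timing stays accurate on GPU (the synchronize still runs there), and a caller that omits `logger` gets output through the standard `logging` module instead of an `AttributeError`.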