Computes the accuracy over the k top predictions for the specified values of k
(output, target, topk=(1,))
| 34 | |
| 35 | |
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k.

    Args:
        output: 2-D score tensor; top-k is taken along dim 1 (the class
            dimension) and dim 0 is the sample dimension.
        target: tensor holding one class index per sample; it is flattened
            and compared against the predicted indices, and ``target.size(0)``
            is used as the batch size.
        topk: iterable of k values to evaluate. ``max(topk)`` must not exceed
            ``output.size(1)``.

    Returns:
        A list with one 1-element float tensor per entry of ``topk`` (same
        order). Each tensor is the percentage (0-100) of samples whose target
        appears among the k highest-scoring predictions.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # Indices of the maxk largest scores along dim 1, sorted descending,
        # then transposed to (maxk, batch): row i holds every sample's
        # (i+1)-th best guess.
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        # reshape (not view) so a non-contiguous target is also accepted.
        correct = pred.eq(target.reshape(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # `correct` is a transposed, non-contiguous tensor, so `.view(-1)`
            # raises on recent PyTorch for k >= 2; `.reshape(-1)` copies only
            # when necessary.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
| 51 | |
| 52 | |
| 53 | def adjust_learning_rate(args, optimizer, epoch): |