Computes the accuracy over the k top predictions for the specified values of k
(output, target, topk=(1,))
| 22 | |
| 23 | |
def accuracy(output, target, topk=(1,)):
    """Compute the top-k accuracy (in percent) for each requested value of k.

    Args:
        output: 2-D tensor of scores with shape ``(batch_size, num_classes)``;
            the top predictions are taken along dim 1. ``max(topk)`` must not
            exceed ``num_classes`` (required by ``Tensor.topk``).
        target: tensor holding ``batch_size`` ground-truth class indices.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one entry per k in ``topk`` (same order). Each entry is a
        one-element float tensor holding the percentage of samples whose true
        class is among the k highest-scoring predictions.

    Raises:
        ZeroDivisionError: if ``target`` is empty (batch size of zero).
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # Indices of the maxk largest scores per sample, sorted descending,
        # transposed to (maxk, batch_size) so row i holds every sample's
        # (i+1)-th best guess.
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.reshape(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # `correct` inherits the transposed (non-contiguous) layout of
            # `pred`, so `.view(-1)` can fail for k > 1; `reshape` copies
            # only when a view is impossible.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res