Computes the top-k accuracy (precision@k) for each of the specified values of k.
(output, target, topk=(1,))
| 64 | |
| 65 | |
def accuracy(output, target, topk=(1,)):
    """Compute the top-k accuracy ("precision@k") for each k in ``topk``.

    Args:
        output: Score tensor. Top-k is taken along dim 1, so dim 0 is the
            batch and dim 1 the classes (assumes a 2-D
            ``(batch, num_classes)`` tensor -- confirm against callers).
        target: Either a 1-D tensor of class indices of length ``batch``,
            or, when it has more than one dimension, a one-hot / score
            encoding whose argmax along dim 1 is used as the class index.
        topk: Iterable of k values to evaluate. Its maximum must not
            exceed the size of ``output`` along dim 1.

    Returns:
        A list with one 0-dim float tensor per entry of ``topk`` (same
        order). Each tensor holds the fraction (0..1, not a percentage) of
        samples whose target is among the k highest-scoring predictions.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # Indices of the maxk largest scores per sample (largest=True,
    # sorted=True), transposed to (maxk, batch) so row i holds every
    # sample's rank-i guess.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()

    # One-hot (or soft) targets: reduce to class indices via argmax.
    if target.ndimension() > 1:
        target = target.max(1)[1]

    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # ``correct`` follows the transposed, non-contiguous layout of
        # ``pred``, so ``view(-1)`` raises for k > 1 on current PyTorch;
        # ``reshape`` returns a view when possible and copies otherwise.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(1.0 / batch_size))

    return res
| 85 | |
| 86 | def save_checkpoint(state, ckpt_dir, is_best=False): |
| 87 | filename = os.path.join(ckpt_dir, 'checkpoint.pth.tar') |