| 34 | |
| 35 | |
class ProgressMeter(object):
    """Print one tab-separated progress line per call to :meth:`display`.

    Each printed line is ``prefix`` immediately followed by a
    ``[current/total]`` batch counter, then ``str()`` of every object in
    ``meters``, all joined with tab characters.
    """

    def __init__(self, num_batches, meters, prefix=""):
        # The "[{:N d}/total]" template is built once here; only the current
        # batch number is substituted on each display() call.
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the progress line for batch ``batch`` to stdout."""
        counter = self.prefix + self.batch_fmtstr.format(batch)
        print('\t'.join([counter] + list(map(str, self.meters))))

    def _get_batch_fmtstr(self, num_batches):
        """Return a ``'[{:Nd}/<num_batches>]'`` template string.

        N is the width of ``str(num_batches // 1)``, so the current-batch
        field is right-aligned to the same width as the total.
        """
        width = len(str(num_batches // 1))
        field = '{:%dd}' % width
        return '[%s/%s]' % (field, field.format(num_batches))
| 51 | |
| 52 | |
| 53 | def accuracy(output, target, topk=(1,)): |