| 87 | self.avg = self.sum / self.count |
| 88 | |
def all_reduce(self):
    """Sum ``self.sum`` and ``self.count`` across ranks and refresh ``self.avg``.

    ``self.sum`` may be a scalar or a ``numpy.ndarray``. The values are packed
    together with ``self.count`` into one float32 tensor so that a single
    ``dist.all_reduce`` call (``ReduceOp.SUM``, blocking) covers both.

    After the call:
      * ``self.sum`` is a float32 ndarray with the original shape if it was an
        ndarray before the call, otherwise a Python float;
      * ``self.count`` is a Python float;
      * ``self.avg`` is ``self.sum / (self.count + 1e-5)``.

    NOTE(review): presumably every rank must hold a ``sum`` with the same
    number of elements for the collective to be valid; confirm with callers.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Remember what kind of value we started with. Inferring it afterwards
    # from the packed tensor's length is ambiguous: a one-element array packs
    # to the same length as the scalar case and would come back as a float.
    sum_is_array = isinstance(self.sum, np.ndarray)
    if sum_is_array:
        sum_shape = self.sum.shape
        # Flatten so 0-d and multi-dimensional arrays pack into a flat list.
        packed = self.sum.ravel().tolist() + [self.count]
    else:
        packed = [self.sum, self.count]
    total = torch.tensor(packed, dtype=torch.float32, device=device)

    dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)

    if sum_is_array:
        # The last slot carries the count; everything before it is the sum.
        self.sum = total[:-1].cpu().numpy().reshape(sum_shape)
        self.count = total[-1].item()
    else:
        self.sum, self.count = total.tolist()
    # The epsilon avoids a division by zero when the reduced count is 0.
    self.avg = self.sum / (self.count + 1e-5)
| 111 | |
| 112 | def __str__(self): |
| 113 | fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" |