(self, logits, labels, loss)
| 136 | return self.update(logits, labels, loss) |
| 137 | |
def update(self, logits, labels, loss, *, ignore_index=-100):
    """Accumulate next-token accuracy and loss statistics for one step.

    The prediction at position ``t`` (argmax over the last dim of
    ``logits``) is compared with ``labels`` at position ``t + 1``.
    Positions whose shifted label equals ``ignore_index`` are excluded
    from both the correct-prediction count and the token total.

    Args:
        logits: Score tensor. The last dim is argmaxed and the
            second-to-last dim is shifted by one. Presumably
            (batch, seq, vocab) — TODO confirm against callers.
        labels: Target-id tensor. Its last dim is shifted by one to align
            with the predictions.
        loss: Object exposing ``.item()`` (e.g. a scalar tensor); its value
            is added to ``self.total_loss``.
        ignore_index: Label value to skip when counting. Defaults to -100.

    Side effects:
        Increments ``self.n_step`` by 1. Adds Python numbers (obtained via
        ``.item()``) to ``self.right``, ``self.total`` and
        ``self.total_loss``.
    """
    self.n_step += 1
    with torch.no_grad():
        shift_preds = logits[..., :-1, :].argmax(dim=-1)
        shift_labels = labels[..., 1:]
        # Build the valid-position mask once and reuse it for both counters.
        valid = shift_labels.ne(ignore_index)
        self.right += ((shift_preds == shift_labels) & valid).sum().item()
        self.total += valid.sum().item()
        self.total_loss += loss.item()
| 146 | |
| 147 | def get_metric(self, reset=True): |
| 148 | dist.all_reduce(self.right, op=torch.distributed.ReduceOp.SUM) |
no outgoing calls
no test coverage detected