| 19 | self.add_state("total", default=tensor(0), dist_reduce_fx="sum") |
| 20 | |
def update(self, result: GraphResult, head: Tensor, labels: Tensor):
    """Accumulate labelled-attachment counts for one batch.

    Arc scores are masked, decoded with ``eisner``, and the decoded heads are
    used to pick a relation prediction per token. A token is counted as
    correct only when both its decoded head and its relation equal the gold
    values; counts are added to ``self.correct`` and ``self.total``.

    Args:
        result: Model output providing ``arc_logits``, ``rel_logits`` and
            ``attention_mask``. It is not modified by this method.
        head: Gold head indices, compared element-wise with the decoded arcs.
        labels: Gold relation labels, compared element-wise with the
            predicted relations.
    """
    # detach() returns a tensor sharing storage with the original, so the
    # in-place masking below would overwrite result.arc_logits. Work on a
    # private copy instead.
    s_arc = result.arc_logits.detach().clone()
    s_rel = result.rel_logits.detach()
    # NOTE(review): the `~`/`&` operators and boolean indexing below assume
    # attention_mask is a bool tensor -- confirm against the caller.
    attention_mask = result.attention_mask

    # Mask out padding positions. The mask gets one extra leading column
    # (a copy of its first column), presumably for the root/CLS slot.
    word_cls_mask = torch.cat([attention_mask[:, :1], attention_mask], dim=1)
    activate_word_mask = word_cls_mask.unsqueeze(-1).expand_as(s_arc)
    # A cell stays active only when both its row and column are unmasked.
    activate_word_mask = activate_word_mask & activate_word_mask.transpose(-1, -2)
    s_arc.masked_fill_(~activate_word_mask, float("-inf"))

    # Mask the root row (index 0, columns 1..) and the diagonal.
    s_arc[:, 0, 1:] = float("-inf")
    s_arc.diagonal(0, 1, 2).fill_(float("-inf"))

    # eisner is fed the flattened scores and per-sentence lengths as numpy arrays.
    s_arc = s_arc.view(-1).cpu().numpy()
    length = torch.sum(word_cls_mask, dim=1).cpu().numpy()
    # NOTE(review): the third argument presumably drops the root entry from
    # each decoded sequence -- confirm against eisner's signature.
    arcs = [tensor(sequence, device=self.device) for sequence in eisner(s_arc, length, True)]
    # Decoded sequences differ in length; pad with 0 to form a batch tensor.
    arcs = torch.nn.utils.rnn.pad_sequence(arcs, batch_first=True, padding_value=0)

    # Skip the first row along dim 1, take the best relation for every
    # candidate head, then keep the one belonging to the decoded head.
    rels = torch.argmax(s_rel[:, 1:], dim=-1)
    rels = rels.gather(-1, arcs.unsqueeze(-1)).squeeze(-1)

    # TODO: also report UAS; only LAS is accumulated for now.
    arc_correct = arcs == head
    rel_correct = rels == labels
    self.correct += (arc_correct & rel_correct)[attention_mask].sum().item()
    self.total += torch.sum(attention_mask).item()
| 49 | |
def compute(self) -> Any:
    """Return the accumulated ratio ``correct / (total + 1e-6)``."""
    # The small constant keeps the division defined when total is zero.
    denominator = self.total + 1e-6
    score = self.correct / denominator
    return score