(self, result: GraphResult, head: Tensor, labels: Tensor)
| 63 | self.add_state("gold_total", default=tensor(0), dist_reduce_fx="sum") |
| 64 | |
def update(self, result: GraphResult, head: Tensor, labels: Tensor):
    """Accumulate matched, gold and predicted entity counts for one batch.

    An arc is predicted wherever the masked arc logit is greater than 0, and
    the relation of each position is the argmax over the last dimension of
    the relation logits. Predicted and gold graphs are both converted with
    ``self.get_graph_entities(..., flatten=True)`` and compared as sets.

    Args:
        result: Model output providing ``arc_logits``, ``rel_logits`` and
            ``attention_mask``.
        head: Gold arcs, passed unchanged to ``get_graph_entities``.
        labels: Gold relation labels, passed unchanged to
            ``get_graph_entities``.
    """
    neg_inf = float("-inf")

    arc_scores = result.arc_logits.detach()
    rel_scores = result.rel_logits.detach()
    mask = result.attention_mask

    # Mask the padding positions. The first mask column is duplicated in
    # front, presumably to cover an extra leading (root) position in the
    # logits -- TODO confirm. The `&` / `~` operators assume a boolean
    # mask -- TODO confirm against the caller.
    word_mask = torch.cat([mask[:, :1], mask], dim=1)
    pair_mask = word_mask.unsqueeze(-1).expand_as(arc_scores)
    pair_mask = pair_mask & pair_mask.transpose(-1, -2)
    arc_scores = arc_scores.masked_fill(~pair_mask, neg_inf)

    # Mask the root row and the diagonal. These writes are in place on the
    # tensor returned by masked_fill above, not on the model's logits.
    arc_scores[:, 0, 1:] = neg_inf
    arc_scores.diagonal(0, 1, 2).fill_(neg_inf)

    # Drop index 0 along dim 1 before decoding predictions.
    arcs = arc_scores[:, 1:, :] > 0
    rels = rel_scores[:, 1:, :].argmax(dim=-1)

    pred_entities = self.get_graph_entities(arcs, rels, flatten=True)
    gold_entities = self.get_graph_entities(head, labels, flatten=True)

    matched = set(gold_entities).intersection(set(pred_entities))
    self.correct += len(matched)
    self.gold_total += len(gold_entities)
    self.pred_total += len(pred_entities)
| 89 | |
def compute(self) -> Any:
    """Return the F1-style score built from the accumulated counts.

    The value is ``2 * correct / (pred_total + gold_total + 1e-6)``; the
    small constant keeps the denominator non-zero when both totals are 0.
    """
    denominator = self.pred_total + self.gold_total + 1e-6
    return 2 * self.correct / denominator
nothing calls this directly
no test coverage detected