hub / github.com/HIT-SCIR/ltp / update

Method update

python/core/ltp_core/models/metrics/graph.py:65–88
(self, result: GraphResult, head: Tensor, labels: Tensor)

Source from the content-addressed store, hash-verified

63          self.add_state("gold_total", default=tensor(0), dist_reduce_fx="sum")
64
65      def update(self, result: GraphResult, head: Tensor, labels: Tensor):
66          s_arc = result.arc_logits.detach()
67          s_rel = result.rel_logits.detach()
68          attention_mask = result.attention_mask
69
70          # mask out the padding positions
71          activate_word_mask = torch.cat([attention_mask[:, :1], attention_mask], dim=1)
72          activate_word_mask = activate_word_mask.unsqueeze(-1).expand_as(s_arc)
73          activate_word_mask = activate_word_mask & activate_word_mask.transpose(-1, -2)
74          s_arc = s_arc.masked_fill(~activate_word_mask, float("-inf"))
75
76          # mask the root and the diagonal
77          s_arc[:, 0, 1:] = float("-inf")
78          s_arc.diagonal(0, 1, 2).fill_(float("-inf"))
79
80          arcs = s_arc[:, 1:, :] > 0
81          rels = torch.argmax(s_rel[:, 1:, :], dim=-1)
82
83          pred_entities = self.get_graph_entities(arcs, rels, flatten=True)
84          gold_entities = self.get_graph_entities(head, labels, flatten=True)
85
86          self.correct += len(set(gold_entities) & set(pred_entities))
87          self.gold_total += len(gold_entities)
88          self.pred_total += len(pred_entities)
89
90      def compute(self) -> Any:
91          return 2 * self.correct / (self.pred_total + self.gold_total + 1e-6)
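The counting in update and the formula in compute amount to a micro-averaged F1: with precision P = correct / pred_total and recall R = correct / gold_total, the harmonic mean 2PR / (P + R) simplifies to 2 * correct / (pred_total + gold_total). The sketch below reproduces only that bookkeeping in plain Python, without torch or torchmetrics. The class name MicroF1Counter and the 4-tuple entity shape are assumptions made for illustration; the listing does not show what get_graph_entities actually returns.

```python
from typing import Iterable, Tuple

# Hypothetical entity shape (an assumption, not taken from the source):
# (sentence index, dependent index, head index, label id).
Entity = Tuple[int, int, int, int]


class MicroF1Counter:
    """Plain-Python stand-in for the three counters kept by the metric."""

    def __init__(self) -> None:
        self.correct = 0
        self.gold_total = 0
        self.pred_total = 0

    def update(self, pred: Iterable[Entity], gold: Iterable[Entity]) -> None:
        pred_list, gold_list = list(pred), list(gold)
        # An entity counts as correct only if it appears in both lists.
        self.correct += len(set(gold_list) & set(pred_list))
        self.gold_total += len(gold_list)
        self.pred_total += len(pred_list)

    def compute(self) -> float:
        # Same expression as the listing; 1e-6 guards against division by zero.
        return 2 * self.correct / (self.pred_total + self.gold_total + 1e-6)


counter = MicroF1Counter()
# Batch 1: two predictions, two gold entities, one match.
counter.update(pred=[(0, 1, 0, 3), (0, 2, 1, 5)], gold=[(0, 1, 0, 3), (0, 2, 1, 4)])
# Batch 2: one prediction, two gold entities, one match.
counter.update(pred=[(1, 1, 0, 2)], gold=[(1, 1, 0, 2), (1, 2, 1, 7)])
print(round(counter.compute(), 3))  # 2 * 2 / (3 + 4) -> 0.571
```

Because the totals are summed across batches before the division, the result is a corpus-level score rather than an average of per-batch scores.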

Callers

nothing calls this directly

Calls 1

get_graph_entities · Method · 0.95

Tested by

no test coverage detected