hub / github.com/HIT-SCIR/ltp / update

Method update

python/core/ltp_core/models/metrics/graph.py:21–48
(self, result: GraphResult, head: Tensor, labels: Tensor)

Source from the content-addressed store, hash-verified

        self.add_state("total", default=tensor(0), dist_reduce_fx="sum")

    def update(self, result: GraphResult, head: Tensor, labels: Tensor):
        s_arc = result.arc_logits.detach()
        s_rel = result.rel_logits.detach()
        attention_mask = result.attention_mask

        # Mask out the padding positions (a root/CLS column is prepended to the mask).
        word_cls_mask = torch.cat([attention_mask[:, :1], attention_mask], dim=1)
        activate_word_mask = word_cls_mask.unsqueeze(-1).expand_as(s_arc)
        activate_word_mask = activate_word_mask & activate_word_mask.transpose(-1, -2)
        s_arc.masked_fill_(~activate_word_mask, float("-inf"))

        # Mask out the root row and the diagonal (no word may be its own head).
        s_arc[:, 0, 1:] = float("-inf")
        s_arc.diagonal(0, 1, 2).fill_(float("-inf"))

        # Decode the arcs with eisner on the CPU, then pad them back into a batch tensor.
        s_arc = s_arc.view(-1).cpu().numpy()
        length = torch.sum(word_cls_mask, dim=1).cpu().numpy()
        arcs = [tensor(sequence, device=self.device) for sequence in eisner(s_arc, length, True)]
        arcs = torch.nn.utils.rnn.pad_sequence(arcs, batch_first=True, padding_value=0)

        # Take the best relation for every (word, head) pair, then keep the one for the decoded head.
        rels = torch.argmax(s_rel[:, 1:], dim=-1)
        rels = rels.gather(-1, arcs.unsqueeze(-1)).squeeze(-1)

        # TODO: also report UAS; only LAS is computed for now.
        arc_correct = arcs == head
        rel_correct = rels == labels
        self.correct += (arc_correct & rel_correct)[attention_mask].sum().item()
        self.total += torch.sum(attention_mask).item()

    def compute(self) -> Any:
        return self.correct / (self.total + 1e-6)
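
The label selection and the scoring at the end of update can be illustrated without tensors. The following is a minimal pure-Python sketch for a single sentence, not the LTP implementation: the helper names select_labels and attachment_scores are hypothetical, head index 0 is assumed to stand for the root, and the unlabeled score (UAS) is added here only to show what the TODO in the source refers to.

```python
from typing import List, Sequence, Tuple


def select_labels(rel_scores: Sequence[Sequence[Sequence[float]]],
                  heads: Sequence[int]) -> List[int]:
    """Pick, for every word, the best label for its chosen head.

    rel_scores[i][h][r] is the score of label r for the arc h -> word i
    (hypothetical layout; head index 0 is assumed to be the root).
    """
    labels = []
    for word, head in enumerate(heads):
        scores = rel_scores[word][head]
        # argmax over the label dimension; ties resolve to the lowest index
        labels.append(max(range(len(scores)), key=scores.__getitem__))
    return labels


def attachment_scores(pred_heads: Sequence[int],
                      pred_labels: Sequence[int],
                      gold_heads: Sequence[int],
                      gold_labels: Sequence[int],
                      mask: Sequence[bool]) -> Tuple[float, float]:
    """Return (UAS, LAS) over the positions where mask is True."""
    total = uas_correct = las_correct = 0
    for ph, pl, gh, gl, keep in zip(pred_heads, pred_labels,
                                    gold_heads, gold_labels, mask):
        if not keep:
            continue  # padding position, ignored like attention_mask in update
        total += 1
        if ph == gh:
            uas_correct += 1          # head is right: counts for UAS
            if pl == gl:
                las_correct += 1      # head and label are right: counts for LAS
    denom = total + 1e-6              # same smoothing term as compute()
    return uas_correct / denom, las_correct / denom


if __name__ == "__main__":
    # Two words, three head candidates (root + 2 words), two labels.
    rel_scores = [
        [[0.1, 0.9], [0.5, 0.2], [0.3, 0.3]],
        [[0.7, 0.1], [0.2, 0.8], [0.0, 0.0]],
    ]
    print(select_labels(rel_scores, heads=[0, 1]))
    print(attachment_scores([0, 1, 2], [1, 1, 0], [0, 1, 1], [1, 0, 0],
                            [True, True, False]))
```

Counting the two conditions separately is the design point: LAS requires both the head and the label to match, so it can never exceed UAS on the same data.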

Callers 6

validation_step · Method · 0.45
validation_epoch_end · Method · 0.45
test_step · Method · 0.45
build_vocabs · Function · 0.45
build_vocabs · Function · 0.45
build_vocabs · Function · 0.45

Calls 2

eisner · Function · 0.85
cpu · Method · 0.80

Tested by 1

test_step · Method · 0.36