MCPcopy
hub / github.com/HIT-SCIR/ltp / update

Method update

python/core/ltp_core/models/metrics/token.py:47–76  ·  view source on GitHub ↗
(self, result: TokenClassifierResult, labels: Tensor, **kwargs)

Source from the content-addressed store, hash-verified

45 self.add_state("pred_total", default=tensor(0), dist_reduce_fx="sum")
46
def update(self, result: TokenClassifierResult, labels: Tensor, **kwargs) -> None:
    """Accumulate entity-level match counts for one batch.

    Gold and predicted tag ids are filtered by ``result.attention_mask`` (only
    positions whose mask value is truthy are kept), mapped to tag names through
    ``self.labels``, joined with ``self.concat`` and converted to entity spans
    with ``get_entities``. The number of spans present in both gold and
    prediction, and the number of spans in each, are added to the ``correct``,
    ``gold_total`` and ``pred_total`` states.

    Args:
        result: model output; its ``crf``, ``logits`` and ``attention_mask``
            attributes are read. When ``crf`` is None, predictions are the
            per-position argmax of ``logits``; otherwise ``crf.decode`` is used.
        labels: gold tag ids, aligned position-by-position with
            ``attention_mask``.
        **kwargs: ignored; accepted so callers may pass extra arguments.
    """
    crf = result.crf
    logits = result.logits
    attention_mask = result.attention_mask

    # Copy the mask to host memory once and reuse it for both gold labels and
    # argmax predictions. Filtering directly against the tensor (as before)
    # evaluates `bool()` on every element; if the mask lives on an accelerator
    # that is one device->host transfer per token.
    mask_np = attention_mask.cpu().numpy()

    def to_tag_names(batch_ids):
        # Drop masked-out positions and translate ids to tag names, per sequence.
        return [
            [self.labels[tag] for tag, keep in zip(ids, row) if keep]
            for ids, row in zip(batch_ids, mask_np)
        ]

    gold = to_tag_names(labels.cpu().numpy())

    if crf is None:
        decoded = to_tag_names(logits.argmax(dim=-1).cpu().numpy())
    else:
        # The CRF decoder receives the original mask tensor, not the NumPy
        # copy, and returns per-sequence id lists that are mapped directly.
        log_probs = torch.log_softmax(logits, dim=-1)
        decoded = [
            [self.labels[tag] for tag in tags]
            for tags in crf.decode(log_probs, attention_mask)
        ]

    gold_entities = get_entities(self.concat(gold))
    pred_entities = get_entities(self.concat(decoded))

    self.correct += len(set(gold_entities) & set(pred_entities))
    self.gold_total += len(gold_entities)
    self.pred_total += len(pred_entities)
77
def compute(self) -> Any:
    """Return the micro-averaged entity F1 over all batches accumulated so far.

    ``2 * correct / (pred_total + gold_total)`` is algebraically the harmonic
    mean of precision (``correct / pred_total``) and recall
    (``correct / gold_total``). The ``1e-6`` added to the denominator keeps the
    division defined when no entities have been counted (the result is then 0).
    """
    doubled_matches = self.correct * 2
    return doubled_matches / (self.pred_total + self.gold_total + 1e-6)

Callers 1

update (Method) · 0.45

Calls 3

concat (Method) · 0.95
cpu (Method) · 0.80
decode (Method) · 0.80

Tested by

no test coverage detected