hub / github.com/HIT-SCIR/ltp / update

Method update

python/core/ltp_core/models/metrics/token.py:95–133
(self, result: TokenClassifierResult, labels: Tensor, **kwargs)

Source from the content-addressed store, hash-verified

    full_state_update: bool = False

    def update(self, result: TokenClassifierResult, labels: Tensor, **kwargs) -> None:
        crf = result.crf
        logits = result.logits
        attention_mask = result.attention_mask

        # Expand the (batch, seq) mask to a pairwise (batch, seq, seq) mask:
        # cell (i, j) is set only when tokens i and j are both unmasked.
        attention_mask = attention_mask.unsqueeze(-1).expand(-1, -1, attention_mask.size(1))
        attention_mask = attention_mask & torch.transpose(attention_mask, -1, -2)
        attention_mask = attention_mask.flatten(end_dim=1)

        # Keep only the rows whose first column is set.
        index = attention_mask[:, 0]
        attention_mask = attention_mask[index]
        logits = logits.flatten(end_dim=1)[index]
        labels = labels.flatten(end_dim=1)[index]

        # Map gold label ids to label strings, skipping masked positions.
        labels = labels.cpu().numpy()
        labels = [
            [self.labels[tag] for tag, mask in zip(tags, masks) if mask]
            for tags, masks in zip(labels, attention_mask)
        ]

        if crf is None:
            # No CRF: take the highest-scoring label at each position.
            decoded = logits.argmax(dim=-1)
            decoded = decoded.cpu().numpy()
            decoded = [
                [self.labels[tag] for tag, mask in zip(tags, masks) if mask]
                for tags, masks in zip(decoded, attention_mask)
            ]
        else:
            # With a CRF: decode over log-probabilities using the same mask.
            logits = torch.log_softmax(logits, dim=-1)
            decoded = crf.decode(logits, attention_mask)
            decoded = [[self.labels[tag] for tag in tags] for tags in decoded]

        gold_entities = get_entities(self.concat(labels))
        pred_entities = get_entities(self.concat(decoded))

        # An entity counts as correct only if it appears in both sets.
        self.correct += len(set(gold_entities) & set(pred_entities))
        self.gold_total += len(gold_entities)
        self.pred_total += len(pred_entities)
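
The mask expansion at the top of the method can be hard to picture from tensor calls alone. The following is a minimal sketch in plain Python lists, not the library's code: `pairwise_mask` and `select_valid_rows` are hypothetical helper names, and the sketch assumes a boolean mask whose first token is always unmasked.

```python
def pairwise_mask(mask):
    """Expand a length-n token mask into an n x n grid.

    Cell (i, j) is True only when tokens i and j are both unmasked,
    which mirrors `mask.unsqueeze(-1).expand(...) & transpose(...)`.
    """
    return [[a and b for b in mask] for a in mask]


def select_valid_rows(batch_masks):
    """Flatten the per-sentence grids and keep rows whose first cell is True.

    This mirrors `flatten(end_dim=1)` followed by indexing with
    `attention_mask[:, 0]`.
    """
    rows = []
    for mask in batch_masks:
        rows.extend(pairwise_mask(mask))
    return [row for row in rows if row[0]]


# Two sentences padded to length 3; the first has one padding token.
batch = [[True, True, False], [True, True, True]]
rows = select_valid_rows(batch)

for row in rows:
    print(row)
```

Under these assumptions, each unmasked token contributes one row, and that row is simply the sentence's own token mask. The padded position of the first sentence contributes no row at all, so five rows remain out of six.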

Callers

nothing calls this directly

Calls 3

cpu · Method · 0.80
decode · Method · 0.80
concat · Method · 0.80

Tested by

no test coverage detected
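
A starting point for coverage might exercise the counting arithmetic in isolation. The sketch below is an assumption-laden stand-in rather than the real metric: `SpanCounter` is a hypothetical class that copies only the three counter updates from the end of `update`, takes already-extracted `(type, start, end)` entity tuples instead of tensors, and adds a `compute` method whose precision/recall/F1 formulas are the conventional ones, not confirmed from the source.

```python
class SpanCounter:
    """Hypothetical stand-in that mirrors only the three counters."""

    def __init__(self):
        self.correct = 0
        self.gold_total = 0
        self.pred_total = 0

    def update(self, gold_entities, pred_entities):
        # Same arithmetic as the last three lines of the real `update`.
        self.correct += len(set(gold_entities) & set(pred_entities))
        self.gold_total += len(gold_entities)
        self.pred_total += len(pred_entities)

    def compute(self):
        # Assumed: conventional precision / recall / F1 with zero guards.
        precision = self.correct / self.pred_total if self.pred_total else 0.0
        recall = self.correct / self.gold_total if self.gold_total else 0.0
        total = precision + recall
        f1 = 2 * precision * recall / total if total else 0.0
        return {"precision": precision, "recall": recall, "f1": f1}


counter = SpanCounter()
batches = [
    # (gold entities, predicted entities) as (type, start, end) tuples
    ([("PER", 0, 1), ("LOC", 3, 3)], [("PER", 0, 1), ("ORG", 3, 3)]),
    ([("ORG", 2, 4)], [("ORG", 2, 4)]),
]
for gold, pred in batches:
    counter.update(gold, pred)

scores = counter.compute()
print(counter.correct, counter.gold_total, counter.pred_total)
print(scores)
```

Because an entity matches only when type, start, and end all agree, the `LOC` versus `ORG` span at position 3 counts as both a missed gold entity and a spurious prediction.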