MCPcopy
hub / github.com/HIT-SCIR/ltp / _cws_post

Method _cws_post

python/interface/ltp/nerual.py:179–256  ·  view source on GitHub ↗
(
        self,
        result: TokenClassifierResult,
        hidden: Dict[str, torch.Tensor],
        store: Dict[str, Any],
        inputs: List[str] = None,
        tokenized: BatchEncoding = None,
    )

Source from the content-addressed store, hash-verified

177
178 @no_grad
179 def _cws_post(
180 self,
181 result: TokenClassifierResult,
182 hidden: Dict[str, torch.Tensor],
183 store: Dict[str, Any],
184 inputs: List[str] = None,
185 tokenized: BatchEncoding = None,
186 ) -> LTPOutput:
187 crf = result.crf
188 logits = result.logits
189 attention_mask = result.attention_mask
190
191 text = []
192 char_idx = []
193 for raw_text, encodings in zip(inputs, tokenized.encodings):
194 last = None
195 text.append([])
196 char_idx.append([])
197 for idx, current in enumerate(encodings.offsets[1:-1]):
198 if current == (0, 0):
199 break
200 elif current[0] == current[1]:
201 continue
202 elif current != last:
203 text[-1].append(raw_text[current[0] : current[1]])
204 char_idx[-1].append(idx)
205 last = current
206 text = ["".join(t) for t in text]
207
208 if crf is None:
209 decoded = logits.argmax(dim=-1)
210 decoded = decoded.cpu().numpy()
211 attention_mask = attention_mask.cpu().numpy()
212
213 decoded = [
214 [self.cws_vocab[tag] for tag, mask in zip(tags, masks) if mask]
215 for tags, masks in zip(decoded, attention_mask)
216 ]
217 else:
218 logits = torch.log_softmax(logits, dim=-1)
219 decoded = crf.decode(logits, attention_mask)
220 decoded = [[self.cws_vocab[tag] for tag in tags] for tags in decoded]
221 entities = [get_entities([d[i] for i in idx]) for d, idx in zip(decoded, char_idx)]
222 entities = [[(e[1], e[2]) for e in se] for se in entities]
223
224 words = [
225 [sent[e[0] : e[1] + 1] for e in sent_entities]
226 for sent, sent_entities in zip(text, entities)
227 ]
228
229 if len(self.hook):
230 words = [self.hook.hook(t, w) for t, w in zip(text, words)]
231 entities = []
232 for he in [np.cumsum([len(w) for w in s]) for s in words]:
233 entities.append([])
234 for i, e in enumerate(he):
235 if i == 0:
236 entities[-1].append((0, e - 1))

Callers

nothing calls this directly

Calls 2

cpu (Method) · 0.80
decode (Method) · 0.80

Tested by

no test coverage detected