MCPcopy
hub / github.com/HIT-SCIR/ltp / _dep_post

Method _dep_post

python/interface/ltp/nerual.py:361–388  ·  view source on GitHub ↗
(
        self,
        result: GraphResult,
        hidden: Dict[str, torch.Tensor],
        store: Dict[str, Any],
        inputs: List[str] = None,
        tokenized: BatchEncoding = None,
    )

Source from the content-addressed store, hash-verified

359
@no_grad
def _dep_post(
    self,
    result: GraphResult,
    hidden: Dict[str, torch.Tensor],
    store: Dict[str, Any],
    inputs: List[str] = None,
    tokenized: BatchEncoding = None,
) -> List[Dict[str, Any]]:
    """Decode dependency-parsing logits into per-sentence heads and labels.

    The arc scores are copied and masked: row 0 gets ``-inf`` for every
    column except 0, and the diagonal is set to ``-inf``. The masked scores
    are then decoded with ``eisner``. Each token's label is the arg-max
    relation (from ``rel_logits``) for the head chosen by the decoder,
    mapped through ``self.dep_vocab``.

    Args:
        result: model output; ``arc_logits``, ``rel_logits`` and
            ``attention_mask`` are read. ``result.arc_logits`` is not
            modified (the masking is applied to a copy).
        hidden: not used by this method (presumably kept so the ``_*_post``
            hooks share one signature — confirm).
        store: not used by this method.
        inputs: not used by this method.
        tokenized: not used by this method.

    Returns:
        One ``{"head": arc, "label": labels}`` dict per sentence. ``arc`` is
        the per-sentence sequence produced by ``eisner``, and ``labels`` is the
        list of relation names aligned with it.
    """
    # Work on a copy so masking does not leak into the caller's GraphResult.
    s_arc = result.arc_logits.clone()
    s_rel = result.rel_logits
    attention_mask = result.attention_mask

    # Mask the root row and the diagonal.
    # NOTE(review): assumes arc_logits is indexed (batch, dependent, head),
    # matching how rel_logits is indexed below — confirm.
    s_arc[:, 0, 1:] = float("-inf")
    s_arc.diagonal(0, 1, 2).fill_(float("-inf"))

    # eisner receives the flattened score matrix plus per-sentence lengths.
    # The +1 presumably counts the root position prepended to each sentence.
    # reshape (not view) so a non-contiguous tensor cannot raise here.
    flat_scores = s_arc.reshape(-1).cpu().numpy()
    length = torch.sum(attention_mask, dim=1).reshape(-1).cpu().numpy() + 1
    arcs = list(eisner(flat_scores, length, True))

    # Best relation id for every (dependent, head) pair, skipping row 0, so
    # rel_ids[s, t, head] is the relation of the t-th decoded token of sentence s.
    rel_ids = torch.argmax(s_rel[:, 1:], dim=-1).cpu().numpy()
    rels = [
        [self.dep_vocab[rel_ids[s, t, head]] for t, head in enumerate(arc)]
        for s, arc in enumerate(arcs)
    ]

    return [{"head": arc, "label": rel} for arc, rel in zip(arcs, rels)]
389
390 @no_grad
391 def _sdp_post(

Callers

nothing calls this directly

Calls 2

eisner (function) · 0.85
cpu (method) · 0.80

Tested by

no test coverage detected