| 359 | |
@no_grad
def _dep_post(
    self,
    result: GraphResult,
    hidden: Dict[str, torch.Tensor],
    store: Dict[str, Any],
    inputs: List[str] = None,
    tokenized: BatchEncoding = None,
) -> LTPOutput:
    """Decode dependency-parse logits into per-sentence heads and labels.

    Masks invalid entries of ``result.arc_logits``, runs ``eisner`` over the
    flattened arc scores to get one head sequence per sentence, then takes
    the arg-max relation from ``result.rel_logits`` for every selected
    (position, head) pair and maps it through ``self.dep_vocab``.

    Args:
        result: Graph output providing ``arc_logits``, ``rel_logits`` and
            ``attention_mask``. It is not modified.
        hidden, store, inputs, tokenized: Not read here. They are presumably
            accepted so all ``_*_post`` hooks share a signature (confirm).

    Returns:
        Despite the ``LTPOutput`` annotation, a plain list with one
        ``{"head": ..., "label": ...}`` dict per sentence. ``head`` is the
        sequence returned by ``eisner``. ``label`` holds the ``dep_vocab``
        entry chosen for each element of that sequence.
    """
    # Work on a copy: the masking below is in place and would otherwise
    # overwrite the caller's result.arc_logits.
    s_arc = result.arc_logits.clone()
    s_rel = result.rel_logits
    attention_mask = result.attention_mask

    # Mask the root row (index 0) beyond column 0, and the diagonal.
    # If axes 1/2 are (dependent, head), the diagonal is a token attaching
    # to itself. That axis order is an assumption; TODO confirm.
    neg_inf = float("-inf")
    s_arc[:, 0, 1:] = neg_inf
    s_arc.diagonal(0, 1, 2).fill_(neg_inf)

    # reshape (unlike view) also works when the logits are non-contiguous.
    flat_scores = s_arc.reshape(-1).cpu().numpy()
    # +1 presumably adds the root position that attention_mask does not
    # cover. Confirm against the encoder's masking.
    length = torch.sum(attention_mask, dim=1).view(-1).cpu().numpy() + 1
    arcs = list(eisner(flat_scores, length, True))

    # Drop position 0 on axis 1 so that index t in each arc sequence lines up
    # with rel_ids[:, t]. This assumes eisner(..., True) omits the root entry.
    rel_ids = torch.argmax(s_rel[:, 1:], dim=-1).cpu().numpy()
    rels = [
        [self.dep_vocab[rel_ids[s, t, a]] for t, a in enumerate(arc)]
        for s, arc in enumerate(arcs)
    ]

    return [{"head": arc, "label": rel} for arc, rel in zip(arcs, rels)]
| 389 | |
| 390 | @no_grad |
| 391 | def _sdp_post( |