(rarcs, rels, labels)
| 448 | |
@staticmethod
def get_graph_entities(rarcs, rels, labels):
    """Decode predicted arcs into per-sequence ``(start, end, label)`` triples.

    Args:
        rarcs: tensor whose non-zero entries mark predicted arcs. Every
            non-zero coordinate is unpacked as ``(sequence index, arc start,
            arc end)``, so the tensor must be 3-dimensional.
        rels: tensor read as ``rels[seq, start, end]`` to obtain the label id
            of each arc; its first dimension sets how many sequences are
            returned.
        labels: indexable (list, dict, ...) mapping a label id to its label.

    Returns:
        A list with ``rels.shape[0]`` inner lists, one per sequence. Each
        inner list holds ``(start, end, label)`` tuples in the order
        ``torch.nonzero`` produces the coordinates.
    """
    num_sequences = rels.shape[0]
    # One [seq, start, end] row per predicted arc, converted to plain ints.
    arc_coords = torch.nonzero(rarcs, as_tuple=False).cpu().detach().numpy().tolist()
    # Host-side copy of the label-id tensor for cheap element lookups.
    rel_ids = rels.cpu().detach().numpy()

    entities = []
    for _ in range(num_sequences):
        entities.append([])

    for seq_idx, start, end in arc_coords:
        entities[seq_idx].append((start, end, labels[rel_ids[seq_idx, start, end]]))

    return entities
| 461 | |
| 462 | @classmethod |
| 463 | def _from_pretrained( |