(arcs, rels, flatten=False)
| 92 | |
@staticmethod
def get_graph_entities(arcs, rels, flatten=False):
    """Collect every marked arc together with its relation label.

    Args:
        arcs: tensor whose non-zero entries mark arcs. Every non-zero
            position is unpacked as ``(idx, start, end)``, so ``arcs`` must
            be 3-D.
        rels: tensor indexed as ``rels[idx, start, end]`` to read the label
            of each arc. Its first dimension determines how many per-index
            groups are returned when ``flatten`` is False.
        flatten: if True, return a single list of
            ``(idx, start, end, label)`` tuples. Otherwise return one list
            per position along ``rels``'s first dimension, each holding
            ``(start, end, label)`` tuples.

    Returns:
        A list of arc tuples, flat or grouped as described above. Arcs
        appear in the order produced by ``torch.nonzero``. ``start``/``end``
        are Python ints, while labels are the elements read from the NumPy
        copy of ``rels``.
    """
    group_count = rels.shape[0]
    # Move both tensors off the autograd graph / device once, up front.
    positions = torch.nonzero(arcs, as_tuple=False).cpu().detach().numpy().tolist()
    label_table = rels.cpu().detach().numpy()

    if flatten:
        return [
            (i, start, end, label_table[i, start, end])
            for i, start, end in positions
        ]

    # One bucket per leading index of ``rels``; arcs are appended to the
    # bucket matching their first coordinate.
    buckets = [[] for _ in range(group_count)]
    for i, start, end in positions:
        buckets[i].append((start, end, label_table[i, start, end]))
    return buckets