github.com/HIT-SCIR/ltp

Method get_graph_entities

python/core/ltp_core/models/metrics/graph.py, lines 94–110
(arcs, rels, flatten=False)


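```python
# Excerpt of a class method; graph.py imports torch at module level.
@staticmethod
def get_graph_entities(arcs, rels, flatten=False):
    # Number of sequences (sentences) in the batch.
    sequence_num = rels.shape[0]
    # Coordinates of every nonzero entry in arcs: one [idx, arc_s, arc_e] list per arc.
    arcs = torch.nonzero(arcs, as_tuple=False).cpu().detach().numpy().tolist()
    # Relation labels as a host-side NumPy array, indexed below per arc.
    rels = rels.cpu().detach().numpy()

    if flatten:
        # A single list of (idx, arc_s, arc_e, label) tuples for the whole batch.
        res = []
    else:
        # One list of (arc_s, arc_e, label) tuples per sequence.
        res = [[] for _ in range(sequence_num)]
    for idx, arc_s, arc_e in arcs:
        # Label stored in rels at the same position as the arc.
        label = rels[idx, arc_s, arc_e]
        if flatten:
            res.append((idx, arc_s, arc_e, label))
        else:
            res[idx].append((arc_s, arc_e, label))

    return res
```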
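To illustrate the two output layouts, here is a minimal NumPy mirror of the same logic. It assumes `arcs` is a `[batch, seq_len, seq_len]` array whose nonzero entries mark arcs and `rels` is a same-shaped array of integer label ids. `np.argwhere` stands in for `torch.nonzero(..., as_tuple=False)`: both return one `[batch_idx, arc_s, arc_e]` row per nonzero entry in row-major order. The function name `graph_entities_np` is hypothetical and not part of LTP.

```python
import numpy as np


def graph_entities_np(arcs, rels, flatten=False):
    """Hypothetical NumPy mirror of get_graph_entities (not part of LTP)."""
    # One [batch_idx, arc_s, arc_e] row per nonzero entry, in row-major order,
    # like torch.nonzero(arcs, as_tuple=False).
    coords = np.argwhere(arcs).tolist()
    res = [] if flatten else [[] for _ in range(rels.shape[0])]
    for idx, arc_s, arc_e in coords:
        # Cast the NumPy scalar label to a plain int so printed tuples are readable.
        label = int(rels[idx, arc_s, arc_e])
        if flatten:
            res.append((idx, arc_s, arc_e, label))
        else:
            res[idx].append((arc_s, arc_e, label))
    return res


# Two sentences, each padded to length 3.
arcs = np.zeros((2, 3, 3), dtype=bool)
rels = np.zeros((2, 3, 3), dtype=np.int64)
arcs[0, 1, 0] = True
rels[0, 1, 0] = 4
arcs[0, 2, 1] = True
rels[0, 2, 1] = 7
arcs[1, 2, 0] = True
rels[1, 2, 0] = 2

print(graph_entities_np(arcs, rels))
# → [[(1, 0, 4), (2, 1, 7)], [(2, 0, 2)]]
print(graph_entities_np(arcs, rels, flatten=True))
# → [(0, 1, 0, 4), (0, 2, 1, 7), (1, 2, 0, 2)]
```

The original method keeps each label as the NumPy scalar returned by `rels[idx, arc_s, arc_e]`; the sketch casts it with `int()` only for readability. In the nested layout, `res` is pre-sized from the batch dimension, so a sentence with no arcs still gets its own empty list and positions stay aligned with the batch.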

Callers 1

update (method) · 0.95

Calls 1

cpu (method) · 0.80

Tested by

no test coverage detected
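This page does not show what the calling `update` method does with these lists. Assuming it compares predicted arcs against gold arcs, a common way to score graph parsers, the sketch below computes labeled precision, recall and F1 from two nested (non-flattened) outputs. `labeled_f1` and the `Arc` alias are hypothetical names for illustration, not LTP's metric code.

```python
from typing import List, Tuple

# Hypothetical alias for one nested entry: (arc_s, arc_e, label).
Arc = Tuple[int, int, int]


def labeled_f1(pred: List[List[Arc]], gold: List[List[Arc]]) -> Tuple[float, float, float]:
    """Hypothetical labeled precision/recall/F1 over per-sentence arc lists."""
    correct = n_pred = n_gold = 0
    for p, g in zip(pred, gold):
        p_set, g_set = set(p), set(g)
        # An arc counts as correct only if start, end and label all match.
        correct += len(p_set & g_set)
        n_pred += len(p_set)
        n_gold += len(g_set)
    precision = correct / n_pred if n_pred else 0.0
    recall = correct / n_gold if n_gold else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


pred = [[(1, 0, 4), (2, 1, 7)], [(2, 0, 2)]]
gold = [[(1, 0, 4), (2, 1, 3)], [(2, 0, 2), (1, 0, 5)]]
print(tuple(round(x, 4) for x in labeled_f1(pred, gold)))
# → (0.6667, 0.5, 0.5714)
```

Here 2 of the 3 predicted arcs match gold exactly, and 2 of the 4 gold arcs are recovered, giving precision 2/3, recall 1/2 and F1 4/7. The arc `(2, 1, 7)` has the right endpoints but the wrong label, so it would count toward an unlabeled score but not this labeled one.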