github.com/THUDM/CogDL

Function updateGraph

cogdl/utils/grb_utils.py:42–57
(graph, adj, features: torch.FloatTensor)


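import scipy.sparse as sp
import torch

# adj2edge is defined in the same module, directly below this function.
# It takes a SciPy CSR matrix and a device and returns (edge_index, edge_attr).


def updateGraph(graph, adj, features: torch.FloatTensor):
    if type(adj) != torch.Tensor:
        # Non-tensor adjacency (expected to be a SciPy CSR matrix): convert directly.
        edge_index, edge_attr = adj2edge(adj, graph.device)
        graph.x = features
        graph.edge_index = edge_index
        graph.edge_attr = edge_attr
    else:
        # Torch adjacency, sparse or dense: detach from autograd, move to CPU,
        # and convert to a SciPy CSR matrix before building the edge list.
        if adj.is_sparse:
            adj_np = sp.csr_matrix(adj.to_dense().detach().cpu().numpy())
        else:
            adj_np = sp.csr_matrix(adj.detach().cpu().numpy())
        edge_index, edge_attr = adj2edge(adj_np, graph.device)
        graph.x = features
        graph.edge_index = edge_index
        graph.edge_attr = edge_attr
        # Only this branch also keeps the original tensor on the graph.
        graph.grb_adj = adj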
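The branch logic above reduces every supported adjacency type to a SciPy CSR matrix before calling adj2edge. Below is a minimal sketch of just that normalization step. It assumes the three input kinds the code distinguishes: a SciPy matrix, a dense torch.Tensor, and a sparse torch.Tensor. The helper name `to_csr` is hypothetical and not part of CogDL. updateGraph passes non-tensor input through unchanged, so the sketch only coerces it to CSR when it arrives in some other SciPy format.

```python
import scipy.sparse as sp
import torch


def to_csr(adj) -> sp.csr_matrix:
    """Normalize an adjacency matrix to SciPy CSR, mirroring updateGraph's branches.

    Hypothetical helper for illustration; not part of CogDL.
    """
    if not isinstance(adj, torch.Tensor):
        # updateGraph hands non-tensor input to adj2edge as-is; this sketch
        # additionally coerces other SciPy formats to CSR.
        return adj if isinstance(adj, sp.csr_matrix) else sp.csr_matrix(adj)
    # Sparse COO tensors are densified first, as in updateGraph.
    dense = adj.to_dense() if adj.is_sparse else adj
    # detach() drops autograd history; cpu() lets NumPy read GPU tensors.
    return sp.csr_matrix(dense.detach().cpu().numpy())


# Usage: the same 3-node weighted graph in three representations.
dense = torch.tensor([[0.0, 1.0, 0.0],
                      [1.0, 0.0, 2.0],
                      [0.0, 2.0, 0.0]], requires_grad=True)
from_dense = to_csr(dense)
from_sparse = to_csr(dense.detach().to_sparse())
from_scipy = to_csr(sp.csr_matrix(dense.detach().numpy()))
print(from_dense.nnz, from_sparse.nnz, from_scipy.nnz)  # 4 4 4
```

In updateGraph itself, only the torch.Tensor branch stores the original tensor as graph.grb_adj. The SciPy branch leaves that attribute untouched.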

Callers (2)

train · method · 0.90
evaluate · function · 0.85

Calls (1)

adj2edge · function · 0.70

Tested by

no test coverage detected
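The body of adj2edge, the one callee listed above, is not shown on this page. Its signature takes a SciPy CSR matrix and a device, and updateGraph assigns the result to graph.edge_index and graph.edge_attr. From that, it presumably turns the nonzero entries into a 2 × E index tensor and an E-length weight vector. The sketch below shows that conversion under those assumptions. `csr_to_edges` is a hypothetical name, and CogDL's real adj2edge may differ in dtypes or edge ordering.

```python
import scipy.sparse as sp
import torch


def csr_to_edges(adj: sp.csr_matrix, device="cpu"):
    """Hypothetical stand-in for adj2edge: CSR matrix -> (edge_index, edge_attr)."""
    coo = adj.tocoo()  # COO exposes parallel row/col/data arrays
    # edge_index has shape [2, E]: row 0 holds source nodes, row 1 targets.
    edge_index = torch.stack(
        [torch.as_tensor(coo.row, dtype=torch.long),
         torch.as_tensor(coo.col, dtype=torch.long)],
        dim=0,
    ).to(device)
    # edge_attr holds one weight per edge, aligned with edge_index columns.
    edge_attr = torch.as_tensor(coo.data, dtype=torch.float).to(device)
    return edge_index, edge_attr


adj = sp.csr_matrix([[0.0, 1.0, 0.0],
                     [1.0, 0.0, 2.0],
                     [0.0, 2.0, 0.0]])
edge_index, edge_attr = csr_to_edges(adj)
print(edge_index.tolist())  # [[0, 1, 1, 2], [1, 0, 2, 1]]
print(edge_attr.tolist())   # [1.0, 1.0, 2.0, 2.0]
```

Keeping the weights in a separate edge_attr vector means a weighted adjacency survives the conversion. A bare edge list would reduce it to 0/1 connectivity.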