hub / github.com/THUDM/CogDL / getGraph

Function getGraph

cogdl/utils/grb_utils.py:27–39
(adj, features: torch.FloatTensor, labels: torch.Tensor = None, device="cpu")


def getGraph(adj, features: torch.FloatTensor, labels: torch.Tensor = None, device="cpu"):
    if type(adj) != torch.Tensor:
        edge_index, edge_attr = adj2edge(adj, device)
        data = Graph(x=features, y=labels, edge_index=edge_index, edge_attr=edge_attr).to(device)
    else:
        if adj.is_sparse:
            adj_np = sp.csr_matrix(adj.to_dense().detach().cpu().numpy())
        else:
            adj_np = sp.csr_matrix(adj.detach().cpu().numpy())
        # print(type(adj_np))
        edge_index, edge_attr = adj2edge(adj_np, device)
        data = Graph(x=features, y=labels, edge_index=edge_index, edge_attr=edge_attr, grb_adj=adj).to(device)
    return data


def updateGraph(graph, adj, features: torch.FloatTensor):
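Both branches of getGraph pass an adjacency matrix to adj2edge, whose body is not shown on this page. The sketch below illustrates the kind of conversion such a helper plausibly performs: turning a SciPy sparse (or dense) adjacency matrix into an edge index and per-edge weights. The name to_edge_list is hypothetical, it returns NumPy arrays rather than tensors, and it ignores the device argument, so treat it as an illustration under those assumptions, not as CogDL's implementation.

```python
import numpy as np
import scipy.sparse as sp


def to_edge_list(adj):
    """Hypothetical stand-in for adj2edge (assumption, not CogDL code).

    Returns (edge_index, edge_attr) as NumPy arrays.
    """
    # coo_matrix accepts dense arrays and other SciPy sparse formats;
    # the COO format exposes row indices, column indices, and values directly.
    coo = sp.coo_matrix(adj)
    edge_index = np.vstack([coo.row, coo.col])  # shape (2, num_edges)
    edge_attr = coo.data                        # one weight per stored edge
    return edge_index, edge_attr


# A 3-node path graph 0 - 1 - 2, stored as a symmetric dense matrix.
dense = np.array(
    [[0, 1, 0],
     [1, 0, 1],
     [0, 1, 0]],
    dtype=np.float32,
)

# Mirror the non-tensor branch of getGraph: the input is already a SciPy matrix.
edge_index, edge_attr = to_edge_list(sp.csr_matrix(dense))
```

In the tensor branch, getGraph first builds the same kind of csr_matrix from the tensor (densifying it if it is sparse) and additionally stores the original tensor on the Graph as grb_adj.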

Callers 15

attack (Method) · 0.90
update_features (Method) · 0.90
attack (Method) · 0.90
update_features (Method) · 0.90
attack (Method) · 0.90
update_features (Method) · 0.90
attack (Method) · 0.90
update_features (Method) · 0.90
attack (Method) · 0.90
update_features (Method) · 0.90
attack (Method) · 0.90
attack (Method) · 0.90

Calls 3

Graph (Class) · 0.90
adj2edge (Function) · 0.70
to (Method) · 0.45

Tested by

no test coverage detected