(adj, features: torch.FloatTensor, labels: torch.Tensor = None, device="cpu")
| 25 | |
| 26 | |
def getGraph(adj, features: torch.FloatTensor, labels: torch.Tensor = None, device="cpu"):
    """Build a ``Graph`` from an adjacency matrix, node features and optional labels.

    Args:
        adj: Adjacency matrix. A ``torch.Tensor`` (or subclass) is converted
            with ``sp.csr_matrix`` before being handed to ``adj2edge``; any
            other object is passed to ``adj2edge`` unchanged.
        features: Node features, stored on the graph as ``x``.
        labels: Optional labels, stored on the graph as ``y``.
        device: Device passed to ``adj2edge`` and to which the resulting
            graph is moved.

    Returns:
        The ``Graph`` moved to ``device``. When ``adj`` is a tensor, the
        original tensor is also attached as ``grb_adj``.
    """
    # isinstance (rather than an exact type comparison) keeps torch.Tensor
    # subclasses such as nn.Parameter on the tensor path.
    if isinstance(adj, torch.Tensor):
        dense_adj = adj.to_dense() if adj.is_sparse else adj
        # Detach and move to host memory before the NumPy conversion.
        adj_csr = sp.csr_matrix(dense_adj.detach().cpu().numpy())
        edge_index, edge_attr = adj2edge(adj_csr, device)
        return Graph(
            x=features,
            y=labels,
            edge_index=edge_index,
            edge_attr=edge_attr,
            grb_adj=adj,
        ).to(device)

    # Non-tensor input is forwarded to adj2edge as-is; presumably a SciPy
    # sparse matrix -- confirm against adj2edge.
    edge_index, edge_attr = adj2edge(adj, device)
    return Graph(
        x=features,
        y=labels,
        edge_index=edge_index,
        edge_attr=edge_attr,
    ).to(device)
| 40 | |
| 41 | |
| 42 | def updateGraph(graph, adj, features: torch.FloatTensor): |
no test coverage detected