| 40 | |
| 41 | |
def _tensor_to_csr(adj):
    """Return a detached CPU ``scipy.sparse.csr_matrix`` copy of tensor ``adj``.

    Dense tensors are converted via NumPy. Sparse COO tensors are converted
    straight from their indices/values so no dense N x N matrix is built.
    """
    adj = adj.detach().cpu()
    if adj.is_sparse:
        # coalesce() sums duplicate indices, exactly as to_dense() would.
        adj = adj.coalesce()
        row, col = adj.indices().numpy()
        csr = sp.csr_matrix(
            (adj.values().numpy(), (row, col)), shape=tuple(adj.shape)
        )
        # csr_matrix(dense_array) never stores zeros; drop explicit zeros
        # so the result matches the previous densify-then-sparsify path.
        csr.eliminate_zeros()
        return csr
    return sp.csr_matrix(adj.numpy())


def updateGraph(graph, adj, features: torch.FloatTensor):
    """Overwrite the node features and connectivity of ``graph`` in place.

    Args:
        graph: Graph object to update. Its ``device`` attribute is forwarded
            to ``adj2edge``; its ``x``, ``edge_index`` and ``edge_attr``
            attributes are overwritten.
        adj: Adjacency matrix. Either a ``torch.Tensor`` (dense or sparse
            COO; subclasses such as ``nn.Parameter`` included) or any other
            object that ``adj2edge`` accepts directly (presumably a
            ``scipy.sparse`` CSR matrix -- confirm against callers).
        features: Node feature matrix assigned to ``graph.x``.

    When ``adj`` is a tensor, the original tensor (not the detached CPU
    copy used to build the edge list) is also stored in ``graph.grb_adj``.
    For non-tensor ``adj`` that attribute is left unchanged.
    """
    # isinstance (not an exact type() check) so Tensor subclasses are
    # converted instead of being passed to adj2edge as if they were scipy.
    is_tensor = isinstance(adj, torch.Tensor)
    adj_sp = _tensor_to_csr(adj) if is_tensor else adj
    edge_index, edge_attr = adj2edge(adj_sp, graph.device)
    graph.x = features
    graph.edge_index = edge_index
    graph.edge_attr = edge_attr
    if is_tensor:
        # Keep the original tensor (possibly carrying autograd history).
        graph.grb_adj = adj
| 58 | |
| 59 | |
| 60 | def adj2edge(adj: sp.csr.csr_matrix, device="cpu"): |