Function to convert a NetworkX graph to edge tensors. Parameters ---------- nx_graph : nx.Graph NetworkX graph. idtype : int32, int64 Integer ID type. Must be int32 or int64. edge_id_attr_name : str, optional Key name for edge ids in the NetworkX graph.
(nx_graph, idtype, edge_id_attr_name=None)
| 65 | |
| 66 | |
def networkx2tensor(nx_graph, idtype, edge_id_attr_name=None):
    """Function to convert a networkx graph to edge tensors.

    Undirected graphs are first converted with ``to_directed()``, so each
    undirected edge becomes two directed edges. Nodes are then relabeled
    to consecutive integers following the sorted order of the original
    node labels.

    Parameters
    ----------
    nx_graph : nx.Graph
        NetworkX graph.
    idtype : int32, int64
        Integer ID type. Must be int32 or int64.
    edge_id_attr_name : str, optional
        Key name for edge ids in the NetworkX graph. When given, every edge
        must carry this attribute, and its values must be distinct integers
        in ``[0, num_edges)``. Output edge ``i`` is the edge whose attribute
        equals ``i``. When None (default), edges are numbered in NetworkX
        iteration order.

        For an undirected input, both directions of an edge share the same
        attribute value after ``to_directed()``. Their IDs are therefore
        duplicates and are rejected.

    Returns
    -------
    (Tensor, Tensor)
        Source and destination node IDs of the edges.

    Raises
    ------
    KeyError
        If ``edge_id_attr_name`` is given but some edge lacks that attribute.
    DGLError
        If an edge ID is out of range or is assigned to more than one edge.
    """
    if not nx_graph.is_directed():
        nx_graph = nx_graph.to_directed()

    # Relabel nodes using consecutive integers
    nx_graph = nx.convert_node_labels_to_integers(nx_graph, ordering="sorted")

    if edge_id_attr_name is None:
        # Iterating the edge view directly; e[0]/e[1] are used rather than
        # tuple unpacking because multigraph views may yield (u, v, key).
        edges = list(nx_graph.edges)
        src = [e[0] for e in edges]
        dst = [e[1] for e in edges]
    else:
        num_edges = nx_graph.number_of_edges()
        src = [0] * num_edges
        dst = [0] * num_edges
        # Records which ID slots are filled. A repeated ID would otherwise
        # overwrite an edge and leave some other slot as a bogus (0, 0) edge.
        seen = [False] * num_edges
        for u, v, attr in nx_graph.edges(data=True):
            eid = int(attr[edge_id_attr_name])
            if eid < 0 or eid >= num_edges:
                raise DGLError(
                    "Expect edge IDs to be a non-negative integer smaller than {:d}, "
                    "got {:d}".format(num_edges, eid)
                )
            if seen[eid]:
                raise DGLError(
                    "Expect edge IDs to be unique, got duplicate ID {:d}".format(eid)
                )
            seen[eid] = True
            src[eid] = u
            dst[eid] = v
    src = F.tensor(src, idtype)
    dst = F.tensor(dst, idtype)
    return src, dst
| 114 | |
| 115 | |
# Immutable two-field record: a sparse-format tag plus the arrays describing it.
# NOTE(review): presumably ``format`` is a name such as "coo"/"csr" and
# ``arrays`` holds the matching index arrays — confirm against callers.
SparseAdjTuple = namedtuple("SparseAdjTuple", "format arrays")
no test coverage detected