Visualizes a heterogeneous graph using networkx.
(
edge_index_dict: Dict[Tuple[str, str, str], Tensor],
edge_weight_dict: Dict[Tuple[str, str, str], Tensor],
path: Optional[str] = None,
backend: Optional[str] = None,
node_labels_dict: Optional[Dict[str, List[str]]] = None,
node_weight_dict: Optional[Dict[str, Tensor]] = None,
node_size_range: Tuple[float, float] = (50, 500),
node_opacity_range: Tuple[float, float] = (1.0, 1.0),
edge_width_range: Tuple[float, float] = (0.1, 2.0),
edge_opacity_range: Tuple[float, float] = (1.0, 1.0),
)
| 153 | |
| 154 | |
| 155 | def visualize_hetero_graph( |
| 156 | edge_index_dict: Dict[Tuple[str, str, str], Tensor], |
| 157 | edge_weight_dict: Dict[Tuple[str, str, str], Tensor], |
| 158 | path: Optional[str] = None, |
| 159 | backend: Optional[str] = None, |
| 160 | node_labels_dict: Optional[Dict[str, List[str]]] = None, |
| 161 | node_weight_dict: Optional[Dict[str, Tensor]] = None, |
| 162 | node_size_range: Tuple[float, float] = (50, 500), |
| 163 | node_opacity_range: Tuple[float, float] = (1.0, 1.0), |
| 164 | edge_width_range: Tuple[float, float] = (0.1, 2.0), |
| 165 | edge_opacity_range: Tuple[float, float] = (1.0, 1.0), |
| 166 | ) -> Any: |
| 167 | """Visualizes a heterogeneous graph using networkx.""" |
| 168 | if backend is not None and backend != "networkx": |
| 169 | raise ValueError("Only 'networkx' backend is supported") |
| 170 | |
| 171 | # Filter out edges with 0 weight |
| 172 | filtered_edge_index_dict = {} |
| 173 | filtered_edge_weight_dict = {} |
| 174 | for edge_type in edge_index_dict.keys(): |
| 175 | mask = edge_weight_dict[edge_type] > 0 |
| 176 | if mask.sum() > 0: |
| 177 | filtered_edge_index_dict[edge_type] = edge_index_dict[ |
| 178 | edge_type][:, mask] |
| 179 | filtered_edge_weight_dict[edge_type] = edge_weight_dict[edge_type][ |
| 180 | mask] |
| 181 | |
| 182 | # Get all unique nodes that are still in the filtered edges |
| 183 | remaining_nodes: Dict[str, Set[int]] = {} |
| 184 | for edge_type, edge_index in filtered_edge_index_dict.items(): |
| 185 | src_type, _, dst_type = edge_type |
| 186 | if src_type not in remaining_nodes: |
| 187 | remaining_nodes[src_type] = set() |
| 188 | if dst_type not in remaining_nodes: |
| 189 | remaining_nodes[dst_type] = set() |
| 190 | remaining_nodes[src_type].update(edge_index[0].tolist()) |
| 191 | remaining_nodes[dst_type].update(edge_index[1].tolist()) |
| 192 | |
| 193 | # Filter node weights to only include remaining nodes |
| 194 | if node_weight_dict is not None: |
| 195 | filtered_node_weight_dict = {} |
| 196 | for node_type, weights in node_weight_dict.items(): |
| 197 | if node_type in remaining_nodes: |
| 198 | mask = torch.zeros(len(weights), dtype=torch.bool) |
| 199 | mask[list(remaining_nodes[node_type])] = True |
| 200 | filtered_node_weight_dict[node_type] = weights[mask] |
| 201 | node_weight_dict = filtered_node_weight_dict |
| 202 | |
| 203 | # Filter node labels to only include remaining nodes |
| 204 | if node_labels_dict is not None: |
| 205 | filtered_node_labels_dict = {} |
| 206 | for node_type, labels in node_labels_dict.items(): |
| 207 | if node_type in remaining_nodes: |
| 208 | filtered_node_labels_dict[node_type] = [ |
| 209 | label for i, label in enumerate(labels) |
| 210 | if i in remaining_nodes[node_type] |
| 211 | ] |
| 212 | node_labels_dict = filtered_node_labels_dict |
No test coverage detected.