MCPcopy
hub / github.com/pyg-team/pytorch_geometric / visualize_hetero_graph

Function visualize_hetero_graph

torch_geometric/visualization/graph.py:155–224  ·  view source on GitHub ↗

Visualizes a heterogeneous graph using networkx.

(
        edge_index_dict: Dict[Tuple[str, str, str], Tensor],
        edge_weight_dict: Dict[Tuple[str, str, str], Tensor],
        path: Optional[str] = None,
        backend: Optional[str] = None,
        node_labels_dict: Optional[Dict[str, List[str]]] = None,
        node_weight_dict: Optional[Dict[str, Tensor]] = None,
        node_size_range: Tuple[float, float] = (50, 500),
        node_opacity_range: Tuple[float, float] = (1.0, 1.0),
        edge_width_range: Tuple[float, float] = (0.1, 2.0),
        edge_opacity_range: Tuple[float, float] = (1.0, 1.0),
)

Source from the content-addressed store, hash-verified

153
154
155def visualize_hetero_graph(
156 edge_index_dict: Dict[Tuple[str, str, str], Tensor],
157 edge_weight_dict: Dict[Tuple[str, str, str], Tensor],
158 path: Optional[str] = None,
159 backend: Optional[str] = None,
160 node_labels_dict: Optional[Dict[str, List[str]]] = None,
161 node_weight_dict: Optional[Dict[str, Tensor]] = None,
162 node_size_range: Tuple[float, float] = (50, 500),
163 node_opacity_range: Tuple[float, float] = (1.0, 1.0),
164 edge_width_range: Tuple[float, float] = (0.1, 2.0),
165 edge_opacity_range: Tuple[float, float] = (1.0, 1.0),
166) -> Any:
167 """Visualizes a heterogeneous graph using networkx."""
168 if backend is not None and backend != "networkx":
169 raise ValueError("Only 'networkx' backend is supported")
170
171 # Filter out edges with 0 weight
172 filtered_edge_index_dict = {}
173 filtered_edge_weight_dict = {}
174 for edge_type in edge_index_dict.keys():
175 mask = edge_weight_dict[edge_type] > 0
176 if mask.sum() > 0:
177 filtered_edge_index_dict[edge_type] = edge_index_dict[
178 edge_type][:, mask]
179 filtered_edge_weight_dict[edge_type] = edge_weight_dict[edge_type][
180 mask]
181
182 # Get all unique nodes that are still in the filtered edges
183 remaining_nodes: Dict[str, Set[int]] = {}
184 for edge_type, edge_index in filtered_edge_index_dict.items():
185 src_type, _, dst_type = edge_type
186 if src_type not in remaining_nodes:
187 remaining_nodes[src_type] = set()
188 if dst_type not in remaining_nodes:
189 remaining_nodes[dst_type] = set()
190 remaining_nodes[src_type].update(edge_index[0].tolist())
191 remaining_nodes[dst_type].update(edge_index[1].tolist())
192
193 # Filter node weights to only include remaining nodes
194 if node_weight_dict is not None:
195 filtered_node_weight_dict = {}
196 for node_type, weights in node_weight_dict.items():
197 if node_type in remaining_nodes:
198 mask = torch.zeros(len(weights), dtype=torch.bool)
199 mask[list(remaining_nodes[node_type])] = True
200 filtered_node_weight_dict[node_type] = weights[mask]
201 node_weight_dict = filtered_node_weight_dict
202
203 # Filter node labels to only include remaining nodes
204 if node_labels_dict is not None:
205 filtered_node_labels_dict = {}
206 for node_type, labels in node_labels_dict.items():
207 if node_type in remaining_nodes:
208 filtered_node_labels_dict[node_type] = [
209 label for i, label in enumerate(labels)
210 if i in remaining_nodes[node_type]
211 ]
212 node_labels_dict = filtered_node_labels_dict

Callers 1

visualize_graph · Method · 0.90

Calls 6

sum · Method · 0.80
keys · Method · 0.45
items · Method · 0.45
update · Method · 0.45
tolist · Method · 0.45

Tested by

no test coverage detected