Class of graph data which is used for DGLGraph construction. Internal use only.
| 269 | |
| 270 | |
class GraphData(BaseData):
    """Class of graph data which is used for DGLGraph construction. Internal use only.

    Holds the graph ids read from a graph-level CSV (as a numpy array) together
    with a dict of per-graph feature columns produced by a data parser.
    """

    def __init__(self, graph_id, data):
        """Store graph ids and their per-graph feature columns.

        Parameters
        ----------
        graph_id : array-like
            Graph ids; converted with ``np.array``.
        data : dict
            Mapping from feature name to per-graph values. Passed, together
            with ``graph_id``, to ``_validate_data_length`` (presumably checks
            that all columns share one length -- confirm against its definition).
        """
        self.graph_id = np.array(graph_id)
        self.data = data
        # An entry named "graph_id" inside ``data`` would override the ids here.
        _validate_data_length({**{"graph_id": self.graph_id}, **self.data})

    @staticmethod
    def load_from_csv(
        meta: MetaGraph, data_parser: Callable, base_dir=None, separator=","
    ) -> "GraphData":
        """Build a ``GraphData`` from the CSV file described by ``meta``.

        Parameters
        ----------
        meta : MetaGraph
            Provides ``file_name`` and ``graph_id_field``.
        data_parser : Callable
            Called with the dataframe (after the id column is popped) and must
            return the feature dict stored as ``GraphData.data``.
        base_dir : optional
            Forwarded to ``BaseData.read_csv``.
        separator : str
            Column separator forwarded to ``BaseData.read_csv``.

        Raises
        ------
        DGLError
            If the graph id column ``meta.graph_id_field`` is not found.
        """
        df = BaseData.read_csv(meta.file_name, base_dir, separator)
        # Presumably removes the id column from ``df`` so the parser only sees
        # feature columns -- confirm against ``BaseData.pop_from_dataframe``.
        graph_ids = BaseData.pop_from_dataframe(df, meta.graph_id_field)
        if graph_ids is None:
            raise DGLError(
                "Missing graph id field [{}] in file [{}].".format(
                    meta.graph_id_field, meta.file_name
                )
            )
        gdata = data_parser(df)
        return GraphData(graph_ids, gdata)

    @staticmethod
    def to_dict(graph_data: "GraphData", graphs_dict: dict) -> tuple:
        """Order graphs by ``graph_data.graph_id`` and tensorize graph features.

        Parameters
        ----------
        graph_data : GraphData
            Graph ids and per-graph feature columns.
        graphs_dict : dict
            Mapping from graph id to graph, built from the node/edge CSVs.
            NOTE: updated in place -- every id present in ``graph_data`` but
            absent here gets a new empty graph inserted.

        Returns
        -------
        tuple
            ``(graphs, data)``: a list of graphs in ``graph_data.graph_id``
            order, and a dict mapping each feature name to a tensor reshaped
            to ``(len(graphs), -1)``.

        Raises
        ------
        DGLError
            If ``graphs_dict`` contains ids that are not in ``graph_data``.
        """
        missing_ids = np.setdiff1d(
            np.array(list(graphs_dict.keys())), graph_data.graph_id
        )
        if len(missing_ids) > 0:
            raise DGLError(
                "Found following graph ids in node/edge CSVs but not in graph CSV: {}.".format(
                    missing_ids
                )
            )
        graphs = []
        # Single pass: fill in placeholders and collect graphs in id order.
        for graph_id in graph_data.graph_id:
            if graph_id not in graphs_dict:
                # Listed in the graph CSV but with no node/edge rows: use an
                # empty graph so every listed id still yields a graph.
                graphs_dict[graph_id] = dgl_heterograph(
                    {("_V", "_E", "_V"): ([], [])}
                )
            graphs.append(graphs_dict[graph_id])
        # One row per graph; any remaining feature dims are flattened.
        data = {
            k: F.reshape(_tensor(v), (len(graphs), -1))
            for k, v in graph_data.data.items()
        }
        return graphs, data
| 319 | |
| 320 | |
| 321 | class DGLGraphConstructor: |
no outgoing calls