Class of node data which is used for DGLGraph construction. Internal use only.
| 126 | |
| 127 | |
| 128 | class NodeData(BaseData): |
| 129 | """Class of node data which is used for DGLGraph construction. Internal use only.""" |
| 130 | |
def __init__(self, node_id, data, type=None, graph_id=None):
    """Store the IDs, features and graph membership of one node type.

    Args:
        node_id: Node identifiers; stored as a numpy array in ``self.id``.
        data: Mapping of column name to per-node values; unpacked together
            with ``id`` and ``graph_id`` for length validation.
        type: Node type name; ``"_V"`` is used when it is ``None``.
        graph_id: Optional per-node graph identifiers. When ``None``, every
            node is assigned graph ``0``.
    """
    self.id = np.array(node_id)
    self.data = data
    self.type = "_V" if type is None else type
    if graph_id is None:
        # Single-graph case: every node belongs to graph 0.
        self.graph_id = np.full(len(node_id), 0)
    else:
        self.graph_id = np.array(graph_id)
    # User columns come last, so a key in `data` overrides id/graph_id here.
    # NOTE(review): _validate_data_length presumably checks that all columns
    # share one length -- confirm against its definition.
    _validate_data_length(
        {"id": self.id, "graph_id": self.graph_id, **self.data}
    )
| 143 | |
@staticmethod
def load_from_csv(
    meta: MetaNode, data_parser: Callable, base_dir=None, separator=","
):
    """Build a ``NodeData`` from the CSV file described by ``meta``.

    The node-ID and graph-ID columns are taken out of the loaded frame via
    ``BaseData.pop_from_dataframe``. The frame is then passed to
    ``data_parser`` to obtain the node data.

    Args:
        meta: Node metadata providing ``file_name``, ``node_id_field``,
            ``graph_id_field`` and ``ntype``.
        data_parser: Callable applied to the frame to produce node data.
        base_dir: Forwarded to ``BaseData.read_csv``.
        separator: Field delimiter forwarded to ``BaseData.read_csv``.

    Raises:
        DGLError: If ``pop_from_dataframe`` returns ``None`` for the
            node-ID field.
    """
    frame = BaseData.read_csv(meta.file_name, base_dir, separator)
    ids = BaseData.pop_from_dataframe(frame, meta.node_id_field)
    per_node_graph = BaseData.pop_from_dataframe(frame, meta.graph_id_field)
    if ids is None:
        raise DGLError(
            f"Missing node id field [{meta.node_id_field}] "
            f"in file [{meta.file_name}]."
        )
    node_type = meta.ntype
    features = data_parser(frame)
    return NodeData(ids, features, type=node_type, graph_id=per_node_graph)
| 160 | |
| 161 | @staticmethod |
| 162 | def to_dict(node_data: List["NodeData"]) -> dict: |
| 163 | # node_ids could be numeric or non-numeric values, but duplication is not allowed. |
| 164 | node_dict = {} |
| 165 | for n_data in node_data: |
| 166 | graph_ids = np.unique(n_data.graph_id) |
| 167 | for graph_id in graph_ids: |
| 168 | idx = n_data.graph_id == graph_id |
| 169 | ids = n_data.id[idx] |
| 170 | u_ids, u_indices, u_counts = np.unique( |
| 171 | ids, return_index=True, return_counts=True |
| 172 | ) |
| 173 | if len(ids) > len(u_ids): |
| 174 | raise DGLError( |
| 175 | "Node IDs are required to be unique but the following ids are duplicate: {}".format( |
| 176 | u_ids[u_counts > 1] |
| 177 | ) |
| 178 | ) |
| 179 | if graph_id not in node_dict: |
| 180 | node_dict[graph_id] = {} |
| 181 | node_dict[graph_id][n_data.type] = { |
| 182 | "mapping": { |
| 183 | index: i for i, index in enumerate(ids[u_indices]) |
| 184 | }, |
| 185 | "data": { |
no outgoing calls