MCPcopy
hub / github.com/dmlc/dgl / NodeData

Class NodeData

python/dgl/data/csv_dataset_base.py:128–191  ·  view source on GitHub ↗

Class of node data which is used for DGLGraph construction. Internal use only.

Source from the content-addressed store, hash-verified

126
127
128class NodeData(BaseData):
129 """Class of node data which is used for DGLGraph construction. Internal use only."""
130
def __init__(self, node_id, data, type=None, graph_id=None):
    """Store node ids, node features, node type and per-node graph ids.

    Parameters
    ----------
    node_id :
        Node identifiers; converted with ``np.array`` and kept as ``self.id``.
    data :
        Mapping of feature name to values; kept as-is in ``self.data`` and
        unpacked into the validation dict below.
    type : optional
        Node type name. Falls back to ``"_V"`` when ``None``.
    graph_id : optional
        Graph id for each node. When ``None``, every node is assigned
        graph id ``0``.
    """
    self.id = np.array(node_id)
    self.data = data

    # "_V" is the fallback type name when the caller gives none.
    if type is None:
        self.type = "_V"
    else:
        self.type = type

    # Without explicit graph ids, all nodes are placed in graph 0.
    if graph_id is None:
        self.graph_id = np.full(len(node_id), 0)
    else:
        self.graph_id = np.array(graph_id)

    # Ids, graph ids and every feature column are validated together;
    # presumably the helper checks they all share one length -- confirm
    # against _validate_data_length.
    columns = {"id": self.id, "graph_id": self.graph_id, **self.data}
    _validate_data_length(columns)
143
@staticmethod
def load_from_csv(
    meta: MetaNode, data_parser: Callable, base_dir=None, separator=","
):
    """Build a ``NodeData`` from the CSV file described by ``meta``.

    Parameters
    ----------
    meta : MetaNode
        Supplies ``file_name``, ``node_id_field``, ``graph_id_field`` and
        ``ntype``.
    data_parser : Callable
        Called with the dataframe that remains after the id columns have
        been popped; its result becomes the node data.
    base_dir : optional
        Forwarded to ``BaseData.read_csv``.
    separator : str
        Forwarded to ``BaseData.read_csv``; defaults to ``","``.

    Returns
    -------
    NodeData

    Raises
    ------
    DGLError
        If popping the node id field yields ``None``.
    """
    frame = BaseData.read_csv(meta.file_name, base_dir, separator)

    # Both id columns are popped before the parser runs, so the parser
    # receives the frame after these two calls.
    popped_node_ids = BaseData.pop_from_dataframe(frame, meta.node_id_field)
    popped_graph_ids = BaseData.pop_from_dataframe(
        frame, meta.graph_id_field
    )

    if popped_node_ids is None:
        message = "Missing node id field [{}] in file [{}].".format(
            meta.node_id_field, meta.file_name
        )
        raise DGLError(message)

    parsed = data_parser(frame)
    return NodeData(
        popped_node_ids, parsed, type=meta.ntype, graph_id=popped_graph_ids
    )
160
161 @staticmethod
162 def to_dict(node_data: List["NodeData"]) -> dict:
163 # node_ids could be numeric or non-numeric values, but duplication is not allowed.
164 node_dict = {}
165 for n_data in node_data:
166 graph_ids = np.unique(n_data.graph_id)
167 for graph_id in graph_ids:
168 idx = n_data.graph_id == graph_id
169 ids = n_data.id[idx]
170 u_ids, u_indices, u_counts = np.unique(
171 ids, return_index=True, return_counts=True
172 )
173 if len(ids) > len(u_ids):
174 raise DGLError(
175 "Node IDs are required to be unique but the following ids are duplicate: {}".format(
176 u_ids[u_counts > 1]
177 )
178 )
179 if graph_id not in node_dict:
180 node_dict[graph_id] = {}
181 node_dict[graph_id][n_data.type] = {
182 "mapping": {
183 index: i for i, index in enumerate(ids[u_indices])
184 },
185 "data": {

Callers 6

_test_NodeEdgeGraphDataFunction · 0.90
load_from_csvMethod · 0.85

Calls

no outgoing calls