MCPcopy
hub / github.com/dmlc/dgl / compute_blocks

Method compute_blocks

python/dgl/graphbolt/minibatch.py:183–304  ·  view source on GitHub ↗

Extracts DGL blocks from `MiniBatch` to construct graphical structures and ID mappings.

(self)

Source from the content-addressed store, hash-verified

181 return self._blocks
182
183 def compute_blocks(self) -> list:
184 """Extracts DGL blocks from `MiniBatch` to construct graphical
185 structures and ID mappings.
186 """
187 from dgl.convert import create_block, EID, NID
188
189 is_heterogeneous = isinstance(
190 self.sampled_subgraphs[0].sampled_csc, Dict
191 )
192
193 # Casts `v.indptr` and `v.indices` to the minimum dtype in-place and returns `v`.
194 def cast_to_minimum_dtype(v: CSCFormatBase):
195 # Checks if number of vertices and edges fit into an int32.
196 dtype = (
197 torch.int32
198 if max(v.indptr.size(0) - 2, v.indices.size(0))
199 <= torch.iinfo(torch.int32).max
200 else torch.int64
201 )
202 v.indptr = v.indptr.to(dtype)
203 v.indices = v.indices.to(dtype)
204 return v
205
206 blocks = []
207 for subgraph in self.sampled_subgraphs:
208 original_row_node_ids = subgraph.original_row_node_ids
209 assert (
210 original_row_node_ids is not None
211 ), "Missing `original_row_node_ids` in sampled subgraph."
212 original_column_node_ids = subgraph.original_column_node_ids
213 assert (
214 original_column_node_ids is not None
215 ), "Missing `original_column_node_ids` in sampled subgraph."
216 if is_heterogeneous:
217 node_types = set()
218 sampled_csc = {}
219 for v in subgraph.sampled_csc.values():
220 cast_to_minimum_dtype(v)
221 for etype, v in subgraph.sampled_csc.items():
222 etype_tuple = etype_str_to_tuple(etype)
223 node_types.add(etype_tuple[0])
224 node_types.add(etype_tuple[2])
225 sampled_csc[etype_tuple] = (
226 "csc",
227 (
228 v.indptr,
229 v.indices,
230 torch.arange(
231 0,
232 len(v.indices),
233 device=v.indptr.device,
234 dtype=v.indptr.dtype,
235 ),
236 ),
237 )
238 num_src_nodes = {
239 ntype: (
240 original_row_node_ids[ntype].size(0)

Callers 1

blocksMethod · 0.95

Calls 7

create_blockFunction · 0.90
etype_str_to_tupleFunction · 0.85
appendMethod · 0.80
valuesMethod · 0.45
itemsMethod · 0.45
sizeMethod · 0.45
getMethod · 0.45

Tested by

no test coverage detected