Extracts DGL blocks from `MiniBatch` to construct graph structures and ID mappings.
(self)
| 181 | return self._blocks |
| 182 | |
| 183 | def compute_blocks(self) -> list: |
| 184 | """Extracts DGL blocks from `MiniBatch` to construct graphical |
| 185 | structures and ID mappings. |
| 186 | """ |
| 187 | from dgl.convert import create_block, EID, NID |
| 188 | |
| 189 | is_heterogeneous = isinstance( |
| 190 | self.sampled_subgraphs[0].sampled_csc, Dict |
| 191 | ) |
| 192 | |
| 193 | # Casts to minimum dtype in-place and returns self. |
| 194 | def cast_to_minimum_dtype(v: CSCFormatBase): |
| 195 | # Checks if number of vertices and edges fit into an int32. |
| 196 | dtype = ( |
| 197 | torch.int32 |
| 198 | if max(v.indptr.size(0) - 2, v.indices.size(0)) |
| 199 | <= torch.iinfo(torch.int32).max |
| 200 | else torch.int64 |
| 201 | ) |
| 202 | v.indptr = v.indptr.to(dtype) |
| 203 | v.indices = v.indices.to(dtype) |
| 204 | return v |
| 205 | |
| 206 | blocks = [] |
| 207 | for subgraph in self.sampled_subgraphs: |
| 208 | original_row_node_ids = subgraph.original_row_node_ids |
| 209 | assert ( |
| 210 | original_row_node_ids is not None |
| 211 | ), "Missing `original_row_node_ids` in sampled subgraph." |
| 212 | original_column_node_ids = subgraph.original_column_node_ids |
| 213 | assert ( |
| 214 | original_column_node_ids is not None |
| 215 | ), "Missing `original_column_node_ids` in sampled subgraph." |
| 216 | if is_heterogeneous: |
| 217 | node_types = set() |
| 218 | sampled_csc = {} |
| 219 | for v in subgraph.sampled_csc.values(): |
| 220 | cast_to_minimum_dtype(v) |
| 221 | for etype, v in subgraph.sampled_csc.items(): |
| 222 | etype_tuple = etype_str_to_tuple(etype) |
| 223 | node_types.add(etype_tuple[0]) |
| 224 | node_types.add(etype_tuple[2]) |
| 225 | sampled_csc[etype_tuple] = ( |
| 226 | "csc", |
| 227 | ( |
| 228 | v.indptr, |
| 229 | v.indices, |
| 230 | torch.arange( |
| 231 | 0, |
| 232 | len(v.indices), |
| 233 | device=v.indptr.device, |
| 234 | dtype=v.indptr.dtype, |
| 235 | ), |
| 236 | ), |
| 237 | ) |
| 238 | num_src_nodes = { |
| 239 | ntype: ( |
| 240 | original_row_node_ids[ntype].size(0) |
no test coverage detected