github.com/dmlc/dgl

Function _get_subgraph_batch_info

tests/python/common/test_batch-graph.py:239–270

Internal test helper that computes the batch information of a subgraph: for each node/edge type key, it counts how many of the subgraph's induced nodes/edges come from each graph in the original batch.

(keys, induced_indices_arr, batch_num_objs)
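To make the semantics concrete, here is a minimal pure-NumPy sketch of the same broadcast-and-count computation for a single key. It assumes DGL's backend `F` can be replaced by plain NumPy arrays, and `per_graph_counts` is a hypothetical name used only for illustration. The induced indices are global IDs in the batched graph, so graph `i` owns the half-open ID range ending at the `i`-th cumulative sum of `batch_num_objs`.

```python
import numpy as np


def per_graph_counts(batch_num_objs, induced_indices):
    """Hypothetical NumPy-only mirror of _get_subgraph_batch_info for one key."""
    # Exclusive upper bound of each graph's ID range, as a column: (num_bkts, 1)
    bucket_offset = np.cumsum(batch_num_objs)[:, None]
    # Induced global IDs as a row: (1, num_objs)
    induced = np.asarray(induced_indices)[None, :]
    # new_offset[i] = number of induced IDs that belong to graphs 0..i
    new_offset = np.sum(induced < bucket_offset, axis=1)
    # Prepend 0 and drop the last entry: cumulative count *before* graph i
    start_offset = np.concatenate([[0], new_offset[:-1]])
    # Per-graph count of induced IDs
    return new_offset - start_offset


# Three graphs with 3, 2 and 4 nodes own global IDs 0..2, 3..4 and 5..8.
counts = per_graph_counts(np.array([3, 2, 4]), np.array([0, 2, 4, 5, 8]))
print(counts)  # [2 1 2]
```

Here `bucket_offset` is `[3, 5, 9]`, `new_offset` is `[2, 3, 5]` and `start_offset` is `[0, 2, 3]`, so the subgraph keeps 2, 1 and 2 nodes from the three original graphs.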


import numpy as np
import backend as F  # DGL's test-suite backend wrapper, as imported in DGL's tests


def _get_subgraph_batch_info(keys, induced_indices_arr, batch_num_objs):
    """Internal function to compute batch information for subgraphs.

    Parameters
    ----------
    keys : List[str]
        The node/edge type keys.
    induced_indices_arr : List[Tensor]
        The induced node/edge index tensor for all node/edge types.
    batch_num_objs : Tensor
        Number of nodes/edges for each graph in the original batch.

    Returns
    -------
    Mapping[str, Tensor]
        A dictionary mapping each node/edge type key to the ``batch_num_objs``
        array of the corresponding subgraph, i.e. the number of induced
        nodes/edges that come from each graph in the original batch.
    """
    bucket_offset = np.expand_dims(
        np.cumsum(F.asnumpy(batch_num_objs), 0), -1
    )  # (num_bkts, 1)
    ret = {}
    for key, induced_indices in zip(keys, induced_indices_arr):
        # NOTE(Zihao): this implementation is not efficient and we can replace it with
        # binary search in the future.
        induced_indices = np.expand_dims(
            F.asnumpy(induced_indices), 0
        )  # (1, num_objs)
        new_offset = np.sum((induced_indices < bucket_offset), 1)  # (num_bkts,)
        # start_offset = [0] + [new_offset[i-1] for i in range(1, n_bkts)]
        # Keep the integer dtype instead of np.zeros' default float64.
        start_offset = np.concatenate(
            [np.zeros((1,), dtype=new_offset.dtype), new_offset[:-1]], 0
        )
        new_batch_num_objs = new_offset - start_offset
        ret[key] = F.tensor(new_batch_num_objs, dtype=F.dtype(batch_num_objs))
    return ret
273@parametrize_idtype

Callers (1)

test_set_batch_info · Function · 0.85

Calls (2)

asnumpy · Method · 0.80
dtype · Method · 0.45

Tested by

no test coverage detected