MCPcopy
hub / github.com/dmlc/dgl / _scatter_nd

Function _scatter_nd

python/dgl/backend/mxnet/sparse.py:38–68  ·  view source on GitHub ↗

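Similar to PyTorch's `scatter_add` along the first dimension (dim 0): values of `src` are summed into a zero-initialized output with `n_rows` rows, at the rows given by `index`.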

(index, src, n_rows)

Source from the content-addressed store, hash-verified

36
37
def _scatter_nd(index, src, n_rows):
    """Scatter-add ``src`` into a zero tensor along the first dimension.

    Similar to PyTorch's ``scatter_add`` with ``dim=0``. For every position
    ``(i, j, ...)`` of ``src``::

        out[index[i, j, ...], j, ...] += src[i, j, ...]

    MXNet has no native scatter_add, so the computation falls back to NumPy
    (``np.add.at``). The result is copied back to the context of ``src``.

    Parameters
    ----------
    index : Tensor
        Row indices into the output; must have the same shape as ``src``.
        NOTE(review): values are presumably in ``[0, n_rows)``. Negative
        values would wrap around under NumPy indexing; confirm that callers
        guarantee this.
    src : Tensor
        Values to accumulate.
    n_rows : int
        Size of the first dimension of the output.

    Returns
    -------
    Tensor
        Tensor of shape ``(n_rows,) + src.shape[1:]`` with ``src``'s dtype,
        placed on the same context as ``src``.
    """
    assert index.shape == src.shape
    dgl_warning("MXNet do not support scatter_add, fallback to numpy.")
    ctx = context(src)
    # Promote to int64 so the flattened offsets computed below cannot
    # overflow when n_rows * prod(shape[1:]) exceeds the range of a 32-bit
    # index dtype. Such an overflow would silently scatter into wrong slots.
    index = asnumpy(index).astype(np.int64, copy=False)
    src = asnumpy(src)
    shp = index.shape
    # Number of elements in one output row (product of the trailing dims;
    # 1 when src is 1-D).
    stride = int(np.prod(shp[1:], dtype=np.int64))
    # Flat output position of element (i, j, ...) is
    #   index[i, j, ...] * stride + ravel_position(j, ...).
    # The trailing part is simply 0..stride-1 laid out in the trailing shape,
    # broadcast across the first dimension.
    col_offsets = np.arange(stride, dtype=np.int64).reshape((1,) + shp[1:])
    flat_idx = (index * stride + col_offsets).reshape(-1)
    rst = np.zeros((stride * n_rows,), dtype=src.dtype)
    # np.add.at is unbuffered, so repeated indices accumulate instead of
    # overwriting each other.
    np.add.at(rst, flat_idx, src.reshape(-1))
    rst = rst.reshape((n_rows,) + shp[1:])
    return copy_to(zerocopy_from_numpy(rst), ctx)
69
70
71def _gather_nd(index, src):

Callers 1

backward · Method · 0.70

Calls 7

dgl_warning · Function · 0.85
append · Method · 0.80
context · Function · 0.70
asnumpy · Function · 0.70
sum · Function · 0.70
copy_to · Function · 0.70
zerocopy_from_numpy · Function · 0.70

Tested by

no test coverage detected