Similar to PyTorch's scatter nd on first dimension.
(index, src, n_rows)
| 36 | |
| 37 | |
def _scatter_nd(index, src, n_rows):
    """Scatter-add ``src`` into a fresh array along the first dimension.

    Similar to PyTorch's scatter nd on first dimension. The work is done in
    NumPy (a warning is emitted saying so) and the result is copied back to
    the context that ``src`` came from.

    Parameters
    ----------
    index : tensor
        Destination row for every element of ``src``. Must have exactly the
        same shape as ``src`` (checked with an ``assert``).
    src : tensor
        Values to accumulate.
    n_rows : int
        Size of the first dimension of the result.

    Returns
    -------
    tensor
        Array of shape ``(n_rows, *index.shape[1:])`` with the dtype of
        ``src``. Entries that map to the same destination are summed;
        destinations nobody writes to stay zero.
    """
    assert index.shape == src.shape
    dgl_warning("MXNet do not support scatter_add, fallback to numpy.")
    device = context(src)
    idx_np = asnumpy(index)
    val_np = asnumpy(src)
    shape = idx_np.shape
    rank = val_np.ndim

    # Visit the trailing axes from last to first. For each one, build the
    # within-row flat offset of every position on that axis, shaped so that
    # it broadcasts against the full index array. After the loop `row_size`
    # is the number of elements in one output row.
    row_size = 1
    axis_terms = []
    for axis in range(rank - 1, 0, -1):
        extent = shape[axis]
        bcast_shape = (1,) * axis + (extent,) + (1,) * (rank - 1 - axis)
        positions = np.arange(extent, dtype=idx_np.dtype)
        axis_terms.append((row_size * positions).reshape(bcast_shape))
        row_size *= extent

    # Flat destination = row index * row size + position inside the row.
    # With a single axis the row index already is the flat destination.
    flat_idx = idx_np * row_size + sum(axis_terms) if rank > 1 else idx_np

    flat_out = np.zeros((row_size * n_rows,), dtype=val_np.dtype)
    # np.add.at is unbuffered, so repeated destinations accumulate.
    np.add.at(flat_out, flat_idx.reshape(-1), val_np.reshape(-1))

    result = flat_out.reshape((n_rows,) + tuple(shape[1:]))
    return copy_to(zerocopy_from_numpy(result), device)
| 69 | |
| 70 | |
| 71 | def _gather_nd(index, src): |
no test coverage detected