MCPcopy
hub / github.com/dmlc/dgl / _scatter_nd

Function _scatter_nd

python/dgl/backend/tensorflow/sparse.py:31–56  ·  view source on GitHub ↗
(index, src, n_rows)

Source from the content-addressed store, hash-verified

29
30
def _scatter_nd(index, src, n_rows):
    """Scatter-add ``src`` into a zero tensor indexed on its first axis.

    Every element ``src[p0, p1, ..., pk]`` is accumulated into
    ``out[index[p0, p1, ..., pk], p1, ..., pk]``: the trailing coordinates
    are kept and only the leading one is taken from ``index``. The work is
    done by flattening to 1-D and calling ``tf.scatter_nd``, which sums
    updates that land on the same position.

    Parameters
    ----------
    index : tf.Tensor
        Integer tensor with exactly the same shape as ``src``. Its values
        pick the output row (presumably in ``[0, n_rows)`` — TODO confirm
        against callers).
    src : tf.Tensor
        Values to scatter.
    n_rows : int
        Length of the first axis of the result.

    Returns
    -------
    tf.Tensor
        Tensor of shape ``(n_rows, *index.shape[1:])``.
    """
    assert index.shape == src.shape
    shape = index.shape
    device = context(src)
    rank = index.ndim

    # Build the row-major flat offset of each trailing coordinate (p1..pk),
    # one axis at a time from the innermost outward. Each partial offset is
    # shaped to broadcast along its own axis only.
    inner_size = 1
    partial_offsets = []
    for axis in range(rank - 1, 0, -1):
        extent = shape[axis]
        broadcast_shape = (1,) * axis + (extent,) + (1,) * (rank - 1 - axis)
        axis_range = tf.range(extent, dtype=index.dtype)
        partial_offsets.append(
            tf.reshape(inner_size * axis_range, broadcast_shape)
        )
        inner_size *= extent

    # inner_size is now prod(shape[1:]), the number of elements per row.
    flat_index = index
    if rank > 1:
        # The offsets are created on the default device, so move their sum
        # next to src before combining.
        flat_index = index * inner_size + copy_to(sum(partial_offsets), device)

    flat_src = tf.reshape(src, (-1,))
    scatter_indices = tf.reshape(flat_index, (-1, 1))
    flat_out = tf.scatter_nd(scatter_indices, flat_src, (inner_size * n_rows,))
    return tf.reshape(flat_out, (n_rows, *shape[1:]))
57
58
59def _gather_nd(index, src):

Callers 1

grad (function) · 0.70

Calls 4

append (method) · 0.80
context (function) · 0.70
copy_to (function) · 0.70
sum (function) · 0.70

Tested by

no test coverage detected