| 29 | |
| 30 | |
def _scatter_nd(index, src, n_rows):
    """Scatter ``src`` along its first axis into ``n_rows`` output rows.

    The multi-dimensional scatter is reduced to a single 1-D
    ``tf.scatter_nd`` call: every element position is converted into a
    linear offset into a flattened ``(n_rows * prod(shape[1:]),)`` buffer,
    which is reshaped back at the end.

    Parameters
    ----------
    index : tensor
        Destination row for each element of ``src``; must have exactly the
        same shape as ``src``.
    src : tensor
        Values to scatter.
    n_rows : int
        Size of the first axis of the result.

    Returns
    -------
    tensor
        Tensor of shape ``(n_rows, *index.shape[1:])``. The element
        ``src[n, j1, ..., jk]`` is written to position
        ``[index[n, j1, ..., jk], j1, ..., jk]``.
        NOTE(review): repeated destinations follow ``tf.scatter_nd``
        semantics (documented by TensorFlow as accumulation) -- confirm
        this is what callers rely on.
    """
    assert index.shape == src.shape
    shape = index.shape
    # Presumably the device/context ``src`` lives on; ``context`` and
    # ``copy_to`` are helpers defined outside this function.
    device = context(src)
    rank = index.ndim

    # Visit the trailing axes from innermost to outermost.  ``row_size`` is
    # the number of elements spanned by one step along the current axis and
    # ends up as the number of elements in one output row.
    row_size = 1
    axis_offsets = []
    for axis in range(rank - 1, 0, -1):
        extent = shape[axis]
        steps = row_size * tf.range(extent, dtype=index.dtype)
        # Shape (1, ..., extent, ..., 1) so it broadcasts against ``index``.
        broadcast_shape = (1,) * axis + (extent,) + (1,) * (rank - 1 - axis)
        axis_offsets.append(tf.reshape(steps, broadcast_shape))
        row_size *= extent

    # linear position = destination_row * row_size + offset_within_row
    flat_index = index
    if rank > 1:
        flat_index = index * row_size + copy_to(sum(axis_offsets), device)

    flat_src = tf.reshape(src, (-1,))
    flat_index = tf.reshape(flat_index, (-1, 1))
    flat_out = tf.scatter_nd(flat_index, flat_src, (row_size * n_rows,))
    return tf.reshape(flat_out, (n_rows, *shape[1:]))
| 57 | |
| 58 | |
| 59 | def _gather_nd(index, src): |