(args: {
inputs: TensorScatterUpdateInputs,
backend: MathBackendCPU,
attrs: TensorScatterUpdateAttrs
})
| 22 | import {scatterImpl} from './Scatter_impl'; |
| 23 | |
/**
 * CPU kernel function for `TensorScatterUpdate`.
 *
 * Reads `tensor`, `indices` and `updates` from `args.inputs`, derives the
 * scatter geometry with `backend_util.calculateShapes`, and hands the three
 * backing buffers to `scatterImpl` with `sumDupeIndices` fixed to `false`.
 *
 * @param args Kernel arguments. Only `inputs` and `backend` are read;
 *     `attrs` is accepted but not used by this function.
 * @returns A tensor info created on `args.backend` that has the shape of
 *     `inputs.tensor` and the dtype and values of the buffer returned by
 *     `scatterImpl`.
 */
export function tensorScatterUpdate(args: {
  inputs: TensorScatterUpdateInputs,
  backend: MathBackendCPU,
  attrs: TensorScatterUpdateAttrs
}): TensorInfo {
  const {tensor, indices, updates} = args.inputs;
  const cpu = args.backend;
  const resultShape = tensor.shape;

  // Geometry of the scatter, computed against the shape of `tensor`.
  const geometry =
      backend_util.calculateShapes(updates, indices, resultShape);

  const indexBuffer = cpu.bufferSync<Rank, 'int32'>(indices);
  const updateBuffer = cpu.bufferSync<Rank, 'int32'|'float32'>(updates);
  const baseBuffer = cpu.bufferSync<Rank, 'int32'|'float32'>(tensor);

  // NOTE(review): passing `false` presumably makes repeated indices replace
  // values instead of accumulating them — confirm against scatterImpl.
  const scattered = scatterImpl(
      indexBuffer, updateBuffer, resultShape, geometry.outputSize,
      geometry.sliceSize, geometry.numUpdates, geometry.sliceRank,
      geometry.strides, baseBuffer, /* sumDupeIndices= */ false);

  return cpu.makeTensorInfo(resultShape, scattered.dtype, scattered.values);
}
| 44 | |
| 45 | export const tensorScatterUpdateConfig: KernelConfig = { |
| 46 | kernelName: TensorScatterUpdate, |
no test coverage detected
searching dependent graphs…