(args: {
inputs: SparseToDenseInputs,
backend: MathBackendCPU,
attrs: SparseToDenseAttrs
})
| 21 | import {scatterImpl} from './Scatter_impl'; |
| 22 | |
/**
 * CPU kernel for SparseToDense.
 *
 * Builds a dense tensor of shape `attrs.outputShape`. It scatters
 * `sparseValues` to the positions listed in `sparseIndices`. It uses the
 * first element of `defaultValue` as the fill value that `scatterImpl`
 * receives for all other positions.
 *
 * Shape bookkeeping comes from `backend_util.calculateShapes`: slice rank,
 * slice size, number of updates, strides and total output size. The scatter
 * itself is delegated to `scatterImpl` with the duplicate-summing flag set to
 * `false`. NOTE(review): how repeated indices are resolved is decided inside
 * scatterImpl; confirm there.
 *
 * @param args.inputs `sparseIndices` (read as an int32 buffer), `sparseValues`,
 *     and `defaultValue` (only its first element is used).
 * @param args.backend CPU backend used to read the input buffers and to wrap
 *     the resulting values in a TensorInfo.
 * @param args.attrs `outputShape` of the dense result.
 * @returns TensorInfo whose dtype and values come from the buffer produced by
 *     scatterImpl.
 * @throws Error `Unsupported type <dtype>` when `sparseValues` is not one of
 *     bool, float32, int32 or string.
 */
export function sparseToDense(args: {
  inputs: SparseToDenseInputs,
  backend: MathBackendCPU,
  attrs: SparseToDenseAttrs
}): TensorInfo {
  const {
    inputs: {sparseIndices, sparseValues, defaultValue},
    backend,
    attrs: {outputShape},
  } = args;

  const shapes =
      backend_util.calculateShapes(sparseValues, sparseIndices, outputShape);
  // Passed to scatterImpl as `sumDupeIndices`.
  const accumulateDuplicates = false;

  const indices = backend.bufferSync<Rank, 'int32'>(sparseIndices);

  // Deliberately a thunk. The default value is read only inside a supported
  // dtype branch, after that branch has fetched the updates buffer. An
  // unsupported dtype therefore reaches the `default` throw without touching
  // defaultValue's data.
  const firstDefault = () => backend.data.get(defaultValue.dataId).values[0];

  let dense;
  switch (sparseValues.dtype) {
    case 'bool': {
      const updates = backend.bufferSync<Rank, 'bool'>(sparseValues);
      const fill = Boolean(firstDefault());
      dense = scatterImpl(
          indices, updates, outputShape, shapes.outputSize, shapes.sliceSize,
          shapes.numUpdates, shapes.sliceRank, shapes.strides, fill,
          accumulateDuplicates);
      break;
    }
    case 'float32': {
      const updates = backend.bufferSync<Rank, 'float32'>(sparseValues);
      const fill = firstDefault() as number;
      dense = scatterImpl(
          indices, updates, outputShape, shapes.outputSize, shapes.sliceSize,
          shapes.numUpdates, shapes.sliceRank, shapes.strides, fill,
          accumulateDuplicates);
      break;
    }
    case 'int32': {
      const updates = backend.bufferSync<Rank, 'int32'>(sparseValues);
      const fill = firstDefault() as number;
      dense = scatterImpl(
          indices, updates, outputShape, shapes.outputSize, shapes.sliceSize,
          shapes.numUpdates, shapes.sliceRank, shapes.strides, fill,
          accumulateDuplicates);
      break;
    }
    case 'string': {
      const updates = backend.bufferSync<Rank, 'string'>(sparseValues);
      // The first element is cast to Uint8Array and decoded with
      // util.decodeString into the string fill value.
      const fill = util.decodeString(firstDefault() as Uint8Array);
      dense = scatterImpl(
          indices, updates, outputShape, shapes.outputSize, shapes.sliceSize,
          shapes.numUpdates, shapes.sliceRank, shapes.strides, fill,
          accumulateDuplicates);
      break;
    }
    default:
      throw new Error(`Unsupported type ${sparseValues.dtype}`);
  }

  return backend.makeTensorInfo(outputShape, dense.dtype, dense.values);
}
nothing calls this directly
no test coverage detected
searching dependent graphs…