(args: {
inputs: SparseToDenseInputs,
backend: WebGPUBackend,
attrs: SparseToDenseAttrs
})
| 26 | import {tile} from './Tile'; |
| 27 | |
| 28 | export function sparseToDense(args: { |
| 29 | inputs: SparseToDenseInputs, |
| 30 | backend: WebGPUBackend, |
| 31 | attrs: SparseToDenseAttrs |
| 32 | }): TensorInfo { |
| 33 | const {inputs, backend, attrs} = args; |
| 34 | const {sparseIndices, sparseValues, defaultValue} = inputs; |
| 35 | const {outputShape} = attrs; |
| 36 | |
| 37 | const {sliceRank, numUpdates, sliceSize, strides, outputSize} = |
| 38 | backend_util.calculateShapes(sparseValues, sparseIndices, outputShape); |
| 39 | |
| 40 | const sumDupeIndices = false; |
| 41 | if (sparseValues.dtype === 'string') { |
| 42 | const indicesBuf = backend.bufferSync<Rank, 'int32'>(sparseIndices); |
| 43 | const updatesBuf = backend.bufferSync<Rank, 'string'>(sparseValues); |
| 44 | const $defaultValue = util.decodeString( |
| 45 | backend.readSync(defaultValue.dataId)[0] as Uint8Array); |
| 46 | const outBuf = scatterImplCPU( |
| 47 | indicesBuf, updatesBuf, outputShape, outputSize, sliceSize, numUpdates, |
| 48 | sliceRank, strides, $defaultValue, sumDupeIndices); |
| 49 | return backend.makeTensorInfo(outputShape, outBuf.dtype, outBuf.values); |
| 50 | } |
| 51 | |
| 52 | const flattenShape = [outputSize / sliceSize, sliceSize]; |
| 53 | |
| 54 | const $sparseIndices = reshape({ |
| 55 | inputs: {x: sparseIndices}, |
| 56 | backend, |
| 57 | attrs: {shape: [numUpdates, sliceRank]} |
| 58 | }); |
| 59 | const $sparseValues = sparseValues.shape.length ? |
| 60 | reshape({ |
| 61 | inputs: {x: sparseValues}, |
| 62 | backend, |
| 63 | attrs: {shape: [numUpdates, sliceSize]} |
| 64 | }) : |
| 65 | identity({inputs: {x: sparseValues}, backend}); |
| 66 | |
| 67 | const type = $sparseValues.dtype; |
| 68 | const zero = |
| 69 | backend.makeTensorInfo([], type, util.makeZerosTypedArray(1, type)); |
| 70 | |
| 71 | // Fill output tensor with the default value. |
| 72 | const $defaultValue = reshape({ |
| 73 | inputs: {x: defaultValue}, |
| 74 | backend, |
| 75 | attrs: {shape: Array(flattenShape.length).fill(1)} |
| 76 | }); |
| 77 | const $denseValues = |
| 78 | tile({inputs: {x: $defaultValue}, backend, attrs: {reps: flattenShape}}); |
| 79 | |
| 80 | const size = util.sizeFromShape([numUpdates, sliceSize]); |
| 81 | const uniformData = [ |
| 82 | {type: 'int32', data: [sliceRank]}, |
| 83 | {type: 'int32', data: strides}, |
| 84 | {type: 'int32', data: [size]}, |
| 85 | ]; |
nothing calls this directly
no test coverage detected
searching dependent graphs…