MCPcopy Index your code
hub / github.com/tensorflow/tfjs / sparseToDense

Function sparseToDense

tfjs-backend-webgpu/src/kernels/SparseToDense.ts:28–130  ·  view source on GitHub ↗
(args: {
  inputs: SparseToDenseInputs,
  backend: WebGPUBackend,
  attrs: SparseToDenseAttrs
})

Source from the content-addressed store, hash-verified

26import {tile} from './Tile';
27
28export function sparseToDense(args: {
29 inputs: SparseToDenseInputs,
30 backend: WebGPUBackend,
31 attrs: SparseToDenseAttrs
32}): TensorInfo {
33 const {inputs, backend, attrs} = args;
34 const {sparseIndices, sparseValues, defaultValue} = inputs;
35 const {outputShape} = attrs;
36
37 const {sliceRank, numUpdates, sliceSize, strides, outputSize} =
38 backend_util.calculateShapes(sparseValues, sparseIndices, outputShape);
39
40 const sumDupeIndices = false;
41 if (sparseValues.dtype === 'string') {
42 const indicesBuf = backend.bufferSync<Rank, 'int32'>(sparseIndices);
43 const updatesBuf = backend.bufferSync<Rank, 'string'>(sparseValues);
44 const $defaultValue = util.decodeString(
45 backend.readSync(defaultValue.dataId)[0] as Uint8Array);
46 const outBuf = scatterImplCPU(
47 indicesBuf, updatesBuf, outputShape, outputSize, sliceSize, numUpdates,
48 sliceRank, strides, $defaultValue, sumDupeIndices);
49 return backend.makeTensorInfo(outputShape, outBuf.dtype, outBuf.values);
50 }
51
52 const flattenShape = [outputSize / sliceSize, sliceSize];
53
54 const $sparseIndices = reshape({
55 inputs: {x: sparseIndices},
56 backend,
57 attrs: {shape: [numUpdates, sliceRank]}
58 });
59 const $sparseValues = sparseValues.shape.length ?
60 reshape({
61 inputs: {x: sparseValues},
62 backend,
63 attrs: {shape: [numUpdates, sliceSize]}
64 }) :
65 identity({inputs: {x: sparseValues}, backend});
66
67 const type = $sparseValues.dtype;
68 const zero =
69 backend.makeTensorInfo([], type, util.makeZerosTypedArray(1, type));
70
71 // Fill output tensor with the default value.
72 const $defaultValue = reshape({
73 inputs: {x: defaultValue},
74 backend,
75 attrs: {shape: Array(flattenShape.length).fill(1)}
76 });
77 const $denseValues =
78 tile({inputs: {x: $defaultValue}, backend, attrs: {reps: flattenShape}});
79
80 const size = util.sizeFromShape([numUpdates, sliceSize]);
81 const uniformData = [
82 {type: 'int32', data: [sliceRank]},
83 {type: 'int32', data: strides},
84 {type: 'int32', data: [size]},
85 ];

Callers

nothing calls this directly

Calls 8

reshape (Function) · 0.90
identity (Function) · 0.90
tile (Function) · 0.90
runWebGPUProgram (Method) · 0.80
readSync (Method) · 0.65
disposeData (Method) · 0.65
bufferSync (Method) · 0.45
makeTensorInfo (Method) · 0.45

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…