MCP · Copy · Index your code
hub / github.com/tensorflow/tfjs / sparseToDense

Function sparseToDense

tfjs-backend-cpu/src/kernels/SparseToDense.ts:23–80  ·  view source on GitHub ↗
(args: {
  inputs: SparseToDenseInputs,
  backend: MathBackendCPU,
  attrs: SparseToDenseAttrs
})

Source from the content-addressed store, hash-verified

21import {scatterImpl} from './Scatter_impl';
22
/**
 * CPU kernel for the `SparseToDense` op.
 *
 * Reads the index and value buffers and decodes the scalar `defaultValue`
 * into a plain JS value matching `sparseValues.dtype`. It then hands
 * everything to `scatterImpl`. The resulting buffer is registered with the
 * backend under `attrs.outputShape`.
 *
 * Duplicate indices: `scatterImpl` is called with its sum-duplicates flag set
 * to false. NOTE(review): this presumably means repeated indices are written
 * rather than accumulated; confirm against `scatterImpl`.
 *
 * @param args.inputs `sparseIndices` (read as int32), `sparseValues`, and
 *     `defaultValue`. Only the first element of `defaultValue` is used.
 * @param args.backend CPU backend used to read tensor data and to create the
 *     output tensor.
 * @param args.attrs `outputShape` of the dense result.
 * @returns TensorInfo for the new tensor, carrying the dtype and values of
 *     the buffer produced by `scatterImpl`.
 * @throws Error when `sparseValues.dtype` is not one of `bool`, `float32`,
 *     `int32` or `string`.
 */
export function sparseToDense(args: {
  inputs: SparseToDenseInputs,
  backend: MathBackendCPU,
  attrs: SparseToDenseAttrs
}): TensorInfo {
  const {sparseIndices, sparseValues, defaultValue} = args.inputs;
  const {backend} = args;
  const outputShape = args.attrs.outputShape;

  const shapes =
      backend_util.calculateShapes(sparseValues, sparseIndices, outputShape);
  // Forwarded to scatterImpl as its sum-duplicate-indices flag.
  const accumulateDuplicates = false;

  const indexBuffer = backend.bufferSync<Rank, 'int32'>(sparseIndices);

  // Deferred so the default tensor's data is only looked up after the dtype
  // of `sparseValues` has been matched; unsupported dtypes throw first.
  const rawDefault = () => backend.data.get(defaultValue.dataId).values[0];

  const dtype = sparseValues.dtype;
  let denseBuffer;
  if (dtype === 'bool') {
    const updates = backend.bufferSync<Rank, 'bool'>(sparseValues);
    const fill = Boolean(rawDefault());
    denseBuffer = scatterImpl(
        indexBuffer, updates, outputShape, shapes.outputSize, shapes.sliceSize,
        shapes.numUpdates, shapes.sliceRank, shapes.strides, fill,
        accumulateDuplicates);
  } else if (dtype === 'float32') {
    const updates = backend.bufferSync<Rank, 'float32'>(sparseValues);
    const fill = rawDefault() as number;
    denseBuffer = scatterImpl(
        indexBuffer, updates, outputShape, shapes.outputSize, shapes.sliceSize,
        shapes.numUpdates, shapes.sliceRank, shapes.strides, fill,
        accumulateDuplicates);
  } else if (dtype === 'int32') {
    const updates = backend.bufferSync<Rank, 'int32'>(sparseValues);
    const fill = rawDefault() as number;
    denseBuffer = scatterImpl(
        indexBuffer, updates, outputShape, shapes.outputSize, shapes.sliceSize,
        shapes.numUpdates, shapes.sliceRank, shapes.strides, fill,
        accumulateDuplicates);
  } else if (dtype === 'string') {
    const updates = backend.bufferSync<Rank, 'string'>(sparseValues);
    // String tensors store their elements as encoded bytes.
    const fill = util.decodeString(rawDefault() as Uint8Array);
    denseBuffer = scatterImpl(
        indexBuffer, updates, outputShape, shapes.outputSize, shapes.sliceSize,
        shapes.numUpdates, shapes.sliceRank, shapes.strides, fill,
        accumulateDuplicates);
  } else {
    throw new Error(`Unsupported type ${dtype}`);
  }

  return backend.makeTensorInfo(
      outputShape, denseBuffer.dtype, denseBuffer.values);
}

Callers

nothing calls this directly

Calls (4)

scatterImpl (function) · 0.90
bufferSync (method) · 0.45
get (method) · 0.45
makeTensorInfo (method) · 0.45

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…