MCPcopy
hub / github.com/tensorflow/tfjs / topKImpl

Function topKImpl

tfjs-backend-cpu/src/kernels/TopK_impl.ts:96–140  ·  view source on GitHub ↗
(
    x: TypedArray, xShape: number[], xDtype: NumericDataType, k: number,
    sorted: boolean)

Source from the content-addressed store, hash-verified

94}
95
/**
 * CPU kernel body for TopK. It keeps `k` entries from every row of the input.
 * A "row" is a run of `xShape[xShape.length - 1]` consecutive values in `x`,
 * so the input is treated as a `[batch, lastDim]` matrix.
 *
 * Which entries are kept is decided by `select`, and their order (when
 * `sorted` is set) by `comparePair`. Both are defined outside this function;
 * presumably they favour the largest values — confirm against their
 * definitions.
 *
 * @param x Flat values of the input tensor. NOTE(review): its length is
 *     assumed to be a multiple of the last dimension; confirm at the call site.
 * @param xShape Shape of the input; its final entry is the axis searched.
 * @param xDtype Dtype used for the returned values buffer.
 * @param k Number of entries kept per row. It must not exceed the last
 *     dimension; if it does, reading the missing pair throws a TypeError.
 * @param sorted If true, each row's kept pairs are ordered with `comparePair`.
 *     Otherwise they stay in the order `select` left them, or in input order
 *     when `k` covers the whole row.
 * @returns `[values, indices]` buffers shaped like `xShape` with its last
 *     dimension replaced by `k`. Indices are positions within each row.
 */
export function topKImpl<T extends Tensor, R extends Rank>(
    x: TypedArray, xShape: number[], xDtype: NumericDataType, k: number,
    sorted: boolean):
    [TensorBuffer<R, NumericDataType>, TensorBuffer<R, 'int32'>] {
  const rowLength = xShape[xShape.length - 1];
  const numRows = x.length / rowLength;
  const outValues = util.getTypedArrayFromDType(xDtype, numRows * k);
  const outIndices = util.getTypedArrayFromDType('int32', numRows * k);

  for (let row = 0; row < numRows; row++) {
    const rowStart = row * rowLength;
    const rowValues = x.subarray(rowStart, rowStart + rowLength);

    // Pair every value with its column so the original position survives the
    // partial selection and the optional sort below.
    const pairs: Pair[] = new Array(rowValues.length);
    for (let col = 0; col < rowValues.length; col++) {
      pairs[col] = {value: rowValues[col], index: col};
    }

    // Partition only when k is a strict subset of the row. `select` is called
    // for its effect on `pairs` (its return value is unused); afterwards the
    // first k slots are the ones kept.
    let kept = pairs;
    if (k < pairs.length) {
      select(pairs, k);
      kept = pairs.slice(0, k);
    }

    if (sorted) {
      kept.sort(comparePair);
    }

    const writeStart = row * k;
    for (let i = 0; i < k; i++) {
      const pair = kept[i];
      outValues[writeStart + i] = pair.value;
      outIndices[writeStart + i] = pair.index;
    }
  }

  // The output keeps every input dimension except the innermost, which
  // becomes k.
  const outShape = xShape.slice();
  outShape[outShape.length - 1] = k;

  const shape = outShape as ShapeMap[R];
  return [
    buffer(shape, xDtype, outValues),
    buffer(shape, 'int32', outIndices),
  ];
}

Callers 1

topKFunction · 0.90

Calls 3

bufferFunction · 0.90
selectFunction · 0.70
sliceMethod · 0.65

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…