| 94 | } |
| 95 | |
/**
 * Computes the top-k entries along the innermost dimension of `x`.
 *
 * The input is viewed as a 2-D matrix of shape [batch, lastDim], where
 * `lastDim = xShape[xShape.length - 1]`. Each row is processed independently.
 *
 * @param x Flat input values; each run of `lastDim` consecutive elements
 *     forms one row.
 * @param xShape Shape of `x`; must have rank >= 1.
 * @param xDtype Dtype used to allocate the output values buffer.
 * @param k Number of entries to keep per row; must be an integer in
 *     [0, lastDim].
 * @param sorted When true, the kept entries of each row are ordered with
 *     `comparePair`; otherwise they stay in the order left by `select`.
 * @returns `[values, indices]` buffers, both shaped like `xShape` except that
 *     the last dimension is `k`. `indices` holds the position of each kept
 *     value within its own row.
 * @throws Error if `xShape` is empty or `k` is not an integer in
 *     [0, lastDim].
 *
 * The type parameter `T` is unused by the implementation; it is retained so
 * callers that pass explicit type arguments keep compiling.
 */
export function topKImpl<T extends Tensor, R extends Rank>(
    x: TypedArray, xShape: number[], xDtype: NumericDataType, k: number,
    sorted: boolean):
    [TensorBuffer<R, NumericDataType>, TensorBuffer<R, 'int32'>] {
  if (xShape.length === 0) {
    throw new Error('topKImpl() expects an input of rank >= 1.');
  }
  // Reshape into a 2d tensor [batch, lastDim] and compute topk along lastDim.
  const lastDim = xShape[xShape.length - 1];
  // Without this check, k > lastDim would only surface later as a TypeError
  // when the copy loop reads past the end of a row's candidate list.
  if (!Number.isInteger(k) || k < 0 || k > lastDim) {
    throw new Error(
        `topKImpl() expects 'k' to be an integer in [0, ${lastDim}], ` +
        `got ${k}.`);
  }
  // Avoid 0 / 0 = NaN when the last dimension is empty.
  const batch = lastDim === 0 ? 0 : x.length / lastDim;
  const size = lastDim;
  const allTopKVals = util.getTypedArrayFromDType(xDtype, batch * k);
  const allTopKIndices = util.getTypedArrayFromDType('int32', batch * k);

  for (let b = 0; b < batch; b++) {
    const offset = b * size;
    const row = x.subarray(offset, offset + size);

    // Pair every value with its position inside this row.
    let valAndInd: Pair[] =
        Array.from(row, (value: number, index: number) => ({value, index}));

    if (k < valAndInd.length) {
      // `select` is expected to move the k winners to the front (presumably
      // the k largest, per its contract — confirm); the rest are dropped.
      select(valAndInd, k);
      valAndInd = valAndInd.slice(0, k);
    }

    if (sorted) {
      valAndInd.sort(comparePair);
    }

    // Row b's results occupy [b * k, (b + 1) * k) in the flat outputs.
    const outOffset = b * k;
    for (let i = 0; i < k; i++) {
      allTopKVals[outOffset + i] = valAndInd[i].value;
      allTopKIndices[outOffset + i] = valAndInd[i].index;
    }
  }
  // Reshape back to the original input shape, except that the last
  // dimension is k.
  const outputShape = xShape.slice();
  outputShape[outputShape.length - 1] = k;

  return [
    buffer(outputShape as ShapeMap[R], xDtype, allTopKVals),
    buffer(outputShape as ShapeMap[R], 'int32', allTopKIndices)
  ];
}