(
args: {inputs: ArgMaxInputs, backend: MathBackendCPU, attrs: ArgMaxAttrs})
| 22 | import {transpose} from './Transpose'; |
| 23 | |
/**
 * CPU kernel for ArgMax: for each slice of `x` along the reduction axis,
 * writes the index of the slice's largest element into an int32 tensor.
 *
 * If the reduction axis is not already innermost (i.e. a non-null
 * permutation is returned by `backend_util.getAxesPermutation`), `x` is
 * transposed first so that each reduced slice is a contiguous run of
 * `reduceSize` values in the flat buffer. That transposed copy is disposed
 * before the function exits, including when an exception is thrown.
 *
 * Only the first parsed axis is reduced (`axes = [axes[0]]`).
 *
 * The scan uses a strict `>` comparison, so ties resolve to the earliest
 * index. A NaN never replaces the running maximum (every comparison with NaN
 * is false); if a slice starts with NaN, index 0 is reported.
 *
 * @param args.inputs.x Input tensor; rejected by `assertNotComplex` if complex.
 * @param args.backend CPU backend that owns the tensor data.
 * @param args.attrs.axis Axis to reduce over, normalised by
 *     `util.parseAxisParam`.
 * @returns An int32 tensor of shape `outShape` holding the argmax indices.
 */
export function argMax(
    args: {inputs: ArgMaxInputs, backend: MathBackendCPU, attrs: ArgMaxAttrs}):
    TensorInfo {
  const {inputs, backend, attrs} = args;
  const {x} = inputs;
  const {axis} = attrs;

  assertNotComplex(x, 'argMax');

  let axes = util.parseAxisParam(axis, x.shape);
  const permutedAxes = backend_util.getAxesPermutation(axes, x.shape.length);
  let $x = x;
  // Tensors created here that must be released before returning.
  const intermediateTensorInfos: TensorInfo[] = [];
  try {
    if (permutedAxes != null) {
      // Move the reduced axis innermost so each slice is contiguous.
      $x = transpose({inputs: {x}, backend, attrs: {perm: permutedAxes}});
      intermediateTensorInfos.push($x);
      axes = backend_util.getInnerMostAxes(axes.length, $x.shape.length);
    }

    axes = [axes[0]];
    backend_util.assertAxesAreInnerMostDims('argMax', axes, $x.shape.length);
    const [outShape, reduceShape] =
        backend_util.computeOutAndReduceShapes($x.shape, axes);

    const outSize = util.sizeFromShape(outShape);
    const vals = util.makeZerosTypedArray(outSize, 'int32');
    const reduceSize = util.sizeFromShape(reduceShape);

    const aVals = backend.data.get($x.dataId).values as TypedArray;
    for (let i = 0; i < vals.length; ++i) {
      const offset = i * reduceSize;
      let max = aVals[offset];
      let maxIndex = 0;
      // Element 0 already seeds `max`, so the scan starts at 1.
      for (let j = 1; j < reduceSize; ++j) {
        const value = aVals[offset + j];
        if (value > max) {
          max = value;
          maxIndex = j;
        }
      }
      vals[i] = maxIndex;
    }

    return backend.makeTensorInfo(outShape, 'int32', vals);
  } finally {
    // Release the transposed copy even if an error was thrown above.
    intermediateTensorInfos.forEach(
        t => backend.disposeIntermediateTensorInfo(t));
  }
}
| 72 | |
| 73 | export const argMaxConfig: KernelConfig = { |
| 74 | kernelName: ArgMax, |
no test coverage detected
searching dependent graphs…