MCPcopy
hub / github.com/tensorflow/tfjs / argMax

Function argMax

tfjs-backend-cpu/src/kernels/ArgMax.ts:24–71  ·  view source on GitHub ↗
(
    args: {inputs: ArgMaxInputs, backend: MathBackendCPU, attrs: ArgMaxAttrs})

Source from the content-addressed store, hash-verified

22import {transpose} from './Transpose';
23
/**
 * CPU kernel for `ArgMax`: for each slice of `x` along the requested axis,
 * writes the index of the largest element of that slice into an `int32`
 * output tensor.
 *
 * If the reduction axis is not already innermost, `x` is transposed first.
 * The transposed copy is disposed before returning.
 *
 * Ties resolve to the lowest index, because only a strictly greater value
 * replaces the current maximum. Comparisons involving NaN are false in JS.
 * As a result, a NaN at index 0 of a slice yields index 0, and a NaN
 * anywhere else is never selected.
 *
 * @param args.inputs.x The input tensor. It is passed to `assertNotComplex`,
 *     which presumably rejects complex dtypes (TODO confirm).
 * @param args.backend The CPU backend that owns the input's data.
 * @param args.attrs.axis The axis to reduce over. Only the first axis
 *     returned by `util.parseAxisParam` is used.
 * @returns A new `int32` tensor. Its shape is the `outShape` produced by
 *     `backend_util.computeOutAndReduceShapes` for the (possibly transposed)
 *     input.
 */
export function argMax(
    args: {inputs: ArgMaxInputs, backend: MathBackendCPU, attrs: ArgMaxAttrs}):
    TensorInfo {
  const {inputs, backend, attrs} = args;
  const {x} = inputs;
  const {axis} = attrs;

  assertNotComplex(x, 'argMax');

  let axes = util.parseAxisParam(axis, x.shape);
  const permutedAxes = backend_util.getAxesPermutation(axes, x.shape.length);
  let $x = x;
  // Tensors allocated by this kernel that must be released before returning.
  const intermediateTensorInfos: TensorInfo[] = [];
  if (permutedAxes != null) {
    // Move the reduction axis into the innermost position. The flat indexing
    // below (`i * reduceSize + j`) relies on that layout, which is checked
    // by assertAxesAreInnerMostDims.
    $x = transpose({inputs: {x}, backend, attrs: {perm: permutedAxes}});
    intermediateTensorInfos.push($x);
    axes = backend_util.getInnerMostAxes(axes.length, $x.shape.length);
  }

  // argMax reduces over a single axis; any additional parsed axes are ignored.
  axes = [axes[0]];
  backend_util.assertAxesAreInnerMostDims('argMax', axes, $x.shape.length);
  const [outShape, reduceShape] =
      backend_util.computeOutAndReduceShapes($x.shape, axes);

  const outSize = util.sizeFromShape(outShape);
  const vals = util.makeZerosTypedArray(outSize, 'int32');
  const reduceSize = util.sizeFromShape(reduceShape);

  const aVals = backend.data.get($x.dataId).values as TypedArray;
  for (let i = 0; i < vals.length; ++i) {
    const offset = i * reduceSize;
    let max = aVals[offset];
    let maxIndex = 0;
    // Element 0 already seeds `max`, so the scan starts at 1.
    // Strict `>` keeps the first index on ties.
    for (let j = 1; j < reduceSize; ++j) {
      const value = aVals[offset + j];
      if (value > max) {
        max = value;
        maxIndex = j;
      }
    }
    vals[i] = maxIndex;
  }

  intermediateTensorInfos.forEach(
      t => backend.disposeIntermediateTensorInfo(t));

  return backend.makeTensorInfo(outShape, 'int32', vals);
}
72
73export const argMaxConfig: KernelConfig = {
74 kernelName: ArgMax,

Callers 2

standardizeWeights (function) · 0.50
arg_max.ts (file) · 0.50

Calls 6

assertNotComplex (function) · 0.90
transpose (function) · 0.90
push (method) · 0.45
get (method) · 0.45
makeTensorInfo (method) · 0.45

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…