MCPcopy
hub / github.com/tensorflow/tfjs / any

Function any

tfjs-backend-cpu/src/kernels/Any.ts:25–77  ·  view source on GitHub ↗
(
    args: {inputs: AnyInputs, backend: MathBackendCPU, attrs: AnyAttrs})

Source from the content-addressed store (hash-verified):

23import {transpose} from './Transpose';
24
/**
 * CPU kernel for `any`: a logical-OR reduction of `x` over `axis`.
 *
 * If the reduced axes are not already the innermost dimensions, `x` is first
 * transposed so that they are. Each output element then corresponds to one
 * contiguous window of `reduceSize` input values.
 *
 * NOTE(review): presumably `x` arrives already cast to `bool` by the
 * higher-level op; this kernel keeps `x`'s dtype for its output — confirm.
 *
 * @param args.inputs.x Input tensor; must not be complex.
 * @param args.backend CPU backend that owns the tensor data.
 * @param args.attrs.axis Axis or axes to reduce (parsed by
 *     `util.parseAxisParam`).
 * @param args.attrs.keepDims When true, the reduced axes are kept with size 1.
 * @returns A new tensor where each element is 1 if any value in its reduction
 *     window is truthy and 0 otherwise (including for empty windows).
 */
export function any(
    args: {inputs: AnyInputs, backend: MathBackendCPU, attrs: AnyAttrs}):
    TensorInfo {
  const {inputs, backend, attrs} = args;
  const {x} = inputs;
  const {axis, keepDims} = attrs;

  assertNotComplex(x, 'any');

  const origAxes = util.parseAxisParam(axis, x.shape);
  let axes = origAxes;
  const permutedAxes = backend_util.getAxesPermutation(axes, x.shape.length);
  let $x = x;
  if (permutedAxes != null) {
    // Move the reduced axes to the end so every reduction window is a
    // contiguous slice of the flat value buffer.
    $x = transpose({inputs: {x}, backend, attrs: {perm: permutedAxes}});
    axes = backend_util.getInnerMostAxes(axes.length, x.shape.length);
  }

  backend_util.assertAxesAreInnerMostDims('any', axes, $x.shape.length);
  const [outShape, reduceShape] =
      backend_util.computeOutAndReduceShapes($x.shape, axes);
  const reduceSize = util.sizeFromShape(reduceShape);
  // Capture the dtype now: `$x` may be disposed before the result is built.
  const dtype = $x.dtype;
  const vals = util.makeZerosTypedArray(util.sizeFromShape(outShape), dtype);

  const aVals = backend.data.get($x.dataId).values as TypedArray;
  for (let i = 0; i < vals.length; ++i) {
    const offset = i * reduceSize;
    // `vals` starts zero-filled, so an empty window (reduceSize === 0) stays
    // false instead of reading `aVals[offset]` outside the window.
    for (let j = 0; j < reduceSize; ++j) {
      if (aVals[offset + j]) {
        vals[i] = 1;
        break;  // A single truthy element decides the whole window.
      }
    }
  }

  if (permutedAxes != null) {
    backend.disposeIntermediateTensorInfo($x);
  }

  const result = backend.makeTensorInfo(outShape, dtype, vals);

  if (keepDims) {
    // Re-insert the reduced axes as size-1 dims in their original positions.
    const expandedShape = backend_util.expandShapeToKeepDim(outShape, origAxes);
    const reshapedResult =
        reshape({inputs: {x: result}, backend, attrs: {shape: expandedShape}});

    backend.disposeIntermediateTensorInfo(result);

    return reshapedResult;
  }

  return result;
}
78
79export const anyConfig: KernelConfig = {
80 kernelName: Any,

Callers 6

computeMaskMethod · 0.50
callMethod · 0.50
any.tsFile · 0.50
detect_saved_modelFunction · 0.50
detect_input_formatFunction · 0.50
validate_input_pathFunction · 0.50

Calls 6

assertNotComplexFunction · 0.90
transposeFunction · 0.90
reshapeFunction · 0.90
getMethod · 0.45
makeTensorInfoMethod · 0.45

Tested by

no test coverage detected