(
args: {inputs: AnyInputs, backend: MathBackendCPU, attrs: AnyAttrs})
| 23 | import {transpose} from './Transpose'; |
| 24 | |
/**
 * CPU kernel for `any`: a logical-OR reduction of `x` over `attrs.axis`.
 *
 * If the reduced axes are not already the innermost ones, `x` is first
 * transposed so that they are. Each output element then corresponds to a
 * contiguous run of `reduceSize` input values. Each output element is set to
 * the first truthy value in its run. If no value is truthy, it gets the run's
 * last value (the same result as folding the run with `||`).
 *
 * @param args.inputs.x Tensor to reduce; complex inputs are rejected.
 * @param args.attrs.axis Axis or axes to reduce over.
 * @param args.attrs.keepDims When true, the result is reshaped so that the
 *     reduced axes are kept in the shape.
 * @returns A new TensorInfo with the same dtype as the (possibly transposed)
 *     input.
 */
export function any(
    args: {inputs: AnyInputs, backend: MathBackendCPU, attrs: AnyAttrs}):
    TensorInfo {
  const {inputs, backend, attrs} = args;
  const {x} = inputs;
  const {axis, keepDims} = attrs;

  assertNotComplex(x, 'any');

  const origAxes = util.parseAxisParam(axis, x.shape);
  let axes = origAxes;
  const permutedAxes = backend_util.getAxesPermutation(axes, x.shape.length);
  let $x = x;
  if (permutedAxes != null) {
    // Move the reduced axes to the end so every reduction window is a
    // contiguous slice of the value buffer.
    $x = transpose({inputs: {x}, backend, attrs: {perm: permutedAxes}});
    axes = backend_util.getInnerMostAxes(axes.length, x.shape.length);
  }

  backend_util.assertAxesAreInnerMostDims('any', axes, $x.shape.length);
  const [outShape, reduceShape] =
      backend_util.computeOutAndReduceShapes($x.shape, axes);
  const reduceSize = util.sizeFromShape(reduceShape);
  const vals = util.makeZerosTypedArray(util.sizeFromShape(outShape), $x.dtype);

  const aVals = backend.data.get($x.dataId).values as TypedArray;
  for (let i = 0; i < vals.length; ++i) {
    const offset = i * reduceSize;
    // Equivalent to folding the window with `||`: the first truthy value wins,
    // otherwise the window's last value is kept. Once a truthy value is seen
    // the result cannot change, so stop scanning. The initial read only
    // matters when the window is empty (reduceSize === 0).
    let anyVal = aVals[offset];
    for (let j = 0; j < reduceSize; ++j) {
      anyVal = aVals[offset + j];
      if (anyVal) {
        break;
      }
    }
    vals[i] = anyVal;
  }

  // Build the result before releasing the intermediate so that `$x` is never
  // read after it has been disposed.
  const result = backend.makeTensorInfo(outShape, $x.dtype, vals);

  if (permutedAxes != null) {
    // The transposed copy was created by this kernel; release it.
    backend.disposeIntermediateTensorInfo($x);
  }

  if (keepDims) {
    const expandedShape = backend_util.expandShapeToKeepDim(outShape, origAxes);
    const reshapedResult =
        reshape({inputs: {x: result}, backend, attrs: {shape: expandedShape}});

    backend.disposeIntermediateTensorInfo(result);

    return reshapedResult;
  }

  return result;
}
| 78 | |
| 79 | export const anyConfig: KernelConfig = { |
| 80 | kernelName: Any, |
no test coverage detected