(
args:
{inputs: MaxPoolInputs, backend: MathBackendCPU, attrs: MaxPoolAttrs})
| 22 | import {identity} from './Identity'; |
| 23 | |
/**
 * CPU kernel for MaxPool.
 *
 * Computes 2D max pooling of `x` using the `filterSize`, `strides`, `pad` and
 * `dimRoundingMode` attributes. Dilation is fixed to 1 for this kernel.
 *
 * @param args.inputs Holds the input tensor `x`. Its shape is cast to a
 *     4-tuple below, so rank 4 is expected (presumably NHWC — TODO confirm).
 * @param args.backend CPU backend used to read `x`'s values and allocate the
 *     result.
 * @param args.attrs Pooling attributes (window size, stride, padding,
 *     rounding mode).
 * @returns A TensorInfo with shape `convInfo.outShape` and dtype `x.dtype`.
 *     When the window is 1x1 and the shape is unchanged, the result comes
 *     from the `identity` kernel instead.
 */
export function maxPool(
    args:
        {inputs: MaxPoolInputs, backend: MathBackendCPU, attrs: MaxPoolAttrs}):
    TensorInfo {
  const {inputs, backend, attrs} = args;
  const {x} = inputs;
  assertNotComplex(x, 'maxPool');
  const {filterSize, strides, pad, dimRoundingMode} = attrs;
  const dilations = 1;

  // NOTE(review): with dilations fixed at 1 this check presumably always
  // passes; kept for parity with the pooling kernels that accept dilations.
  util.assert(
      backend_util.eitherStridesOrDilationsAreOne(strides, dilations),
      () => 'Error in maxPool: Either strides or dilations must be 1. ' +
          `Got strides ${strides} and dilations '${dilations}'`);

  const convInfo = backend_util.computePool2DInfo(
      x.shape as [number, number, number, number], filterSize, strides,
      dilations, pad, dimRoundingMode);

  // A 1x1 window that leaves the shape unchanged takes the max over a single
  // element, i.e. returns the input as-is, so delegate to the identity kernel
  // instead of running the pooling loop.
  if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 &&
      util.arraysEqual(convInfo.inShape, convInfo.outShape)) {
    return identity({inputs: {x}, backend});
  }

  const xValues = backend.data.get(x.dataId).values as TypedArray;
  // Memory strides of x's flat buffer. These are named `xStrides` so they do
  // not shadow the pooling-window `strides` attribute destructured above.
  const xStrides = util.computeStrides(x.shape);
  const buffer = pool(xValues, x.shape, x.dtype, xStrides, convInfo, 'max');
  return backend.makeTensorInfo(
      convInfo.outShape, x.dtype, buffer.values as TypedArray);
}
| 56 | |
| 57 | export const maxPoolConfig: KernelConfig = { |
| 58 | kernelName: MaxPool, |
no test coverage detected
searching dependent graphs…