(
xValues: TypedArray, xShape: number[], dtype: DataType, strides: number[],
convInfo: backend_util.Conv2DInfo,
poolType: 'max'|'avg')
| 18 | import {backend_util, buffer, DataType, Rank, TensorBuffer, TypedArray} from '@tensorflow/tfjs-core'; |
| 19 | |
| 20 | export function pool( |
| 21 | xValues: TypedArray, xShape: number[], dtype: DataType, strides: number[], |
| 22 | convInfo: backend_util.Conv2DInfo, |
| 23 | poolType: 'max'|'avg'): TensorBuffer<Rank, DataType> { |
| 24 | const strideHeight = convInfo.strideHeight; |
| 25 | const strideWidth = convInfo.strideWidth; |
| 26 | const dilationHeight = convInfo.dilationHeight; |
| 27 | const dilationWidth = convInfo.dilationWidth; |
| 28 | const effectiveFilterHeight = convInfo.effectiveFilterHeight; |
| 29 | const effectiveFilterWidth = convInfo.effectiveFilterWidth; |
| 30 | const padTop = convInfo.padInfo.top; |
| 31 | const padLeft = convInfo.padInfo.left; |
| 32 | |
| 33 | const initialValue = |
| 34 | (poolType === 'max' ? Number.NEGATIVE_INFINITY : |
| 35 | Number.POSITIVE_INFINITY); |
| 36 | |
| 37 | const output = buffer(convInfo.outShape, dtype); |
| 38 | const outputVals = output.values; |
| 39 | |
| 40 | const outputBatchStrides = |
| 41 | convInfo.outShape[1] * convInfo.outShape[2] * convInfo.outShape[3]; |
| 42 | const outputRowStrides = convInfo.outShape[2] * convInfo.outShape[3]; |
| 43 | const outputColStrides = convInfo.outShape[3]; |
| 44 | |
| 45 | for (let b = 0; b < convInfo.batchSize; ++b) { |
| 46 | const outputBatchOffset = b * outputBatchStrides; |
| 47 | const inputBatchOffset = b * strides[0]; |
| 48 | for (let d = 0; d < convInfo.inChannels; ++d) { |
| 49 | for (let yR = 0; yR < convInfo.outHeight; ++yR) { |
| 50 | const xRCorner = yR * strideHeight - padTop; |
| 51 | const xRMin = Math.max(0, xRCorner); |
| 52 | const xRMax = |
| 53 | Math.min(convInfo.inHeight, effectiveFilterHeight + xRCorner); |
| 54 | const outputRowOffset = outputBatchOffset + yR * outputRowStrides; |
| 55 | for (let yC = 0; yC < convInfo.outWidth; ++yC) { |
| 56 | const xCCorner = yC * strideWidth - padLeft; |
| 57 | const xCMin = Math.max(0, xCCorner); |
| 58 | const xCMax = |
| 59 | Math.min(convInfo.inWidth, effectiveFilterWidth + xCCorner); |
| 60 | let minMaxValue = initialValue; |
| 61 | let avgValue = 0; |
| 62 | let count = 0; |
| 63 | for (let xR = xRMin; xR < xRMax; xR += dilationHeight) { |
| 64 | const xROffset = inputBatchOffset + xR * strides[1]; |
| 65 | for (let xC = xCMin; xC < xCMax; xC += dilationWidth) { |
| 66 | const xCOffset = xROffset + xC * strides[2]; |
| 67 | const pixel = xValues[xCOffset + d]; |
| 68 | if ((poolType === 'max' && pixel > minMaxValue)) { |
| 69 | minMaxValue = pixel; |
| 70 | } else if (poolType === 'avg') { |
| 71 | avgValue += pixel; |
| 72 | count++; |
| 73 | } |
| 74 | } |
| 75 | if (isNaN(minMaxValue)) { |
| 76 | break; |
| 77 | } |
no test coverage detected
searching dependent graphs…