MCPcopy
hub / github.com/tensorflow/tfjs / pool

Function pool

tfjs-backend-cpu/src/utils/pool_utils.ts:20–87  ·  view source on GitHub ↗
(
    xValues: TypedArray, xShape: number[], dtype: DataType, strides: number[],
    convInfo: backend_util.Conv2DInfo,
    poolType: 'max'|'avg')

Source from the content-addressed store, hash-verified

18import {backend_util, buffer, DataType, Rank, TensorBuffer, TypedArray} from '@tensorflow/tfjs-core';
19
20export function pool(
21 xValues: TypedArray, xShape: number[], dtype: DataType, strides: number[],
22 convInfo: backend_util.Conv2DInfo,
23 poolType: 'max'|'avg'): TensorBuffer<Rank, DataType> {
24 const strideHeight = convInfo.strideHeight;
25 const strideWidth = convInfo.strideWidth;
26 const dilationHeight = convInfo.dilationHeight;
27 const dilationWidth = convInfo.dilationWidth;
28 const effectiveFilterHeight = convInfo.effectiveFilterHeight;
29 const effectiveFilterWidth = convInfo.effectiveFilterWidth;
30 const padTop = convInfo.padInfo.top;
31 const padLeft = convInfo.padInfo.left;
32
33 const initialValue =
34 (poolType === 'max' ? Number.NEGATIVE_INFINITY :
35 Number.POSITIVE_INFINITY);
36
37 const output = buffer(convInfo.outShape, dtype);
38 const outputVals = output.values;
39
40 const outputBatchStrides =
41 convInfo.outShape[1] * convInfo.outShape[2] * convInfo.outShape[3];
42 const outputRowStrides = convInfo.outShape[2] * convInfo.outShape[3];
43 const outputColStrides = convInfo.outShape[3];
44
45 for (let b = 0; b < convInfo.batchSize; ++b) {
46 const outputBatchOffset = b * outputBatchStrides;
47 const inputBatchOffset = b * strides[0];
48 for (let d = 0; d < convInfo.inChannels; ++d) {
49 for (let yR = 0; yR < convInfo.outHeight; ++yR) {
50 const xRCorner = yR * strideHeight - padTop;
51 const xRMin = Math.max(0, xRCorner);
52 const xRMax =
53 Math.min(convInfo.inHeight, effectiveFilterHeight + xRCorner);
54 const outputRowOffset = outputBatchOffset + yR * outputRowStrides;
55 for (let yC = 0; yC < convInfo.outWidth; ++yC) {
56 const xCCorner = yC * strideWidth - padLeft;
57 const xCMin = Math.max(0, xCCorner);
58 const xCMax =
59 Math.min(convInfo.inWidth, effectiveFilterWidth + xCCorner);
60 let minMaxValue = initialValue;
61 let avgValue = 0;
62 let count = 0;
63 for (let xR = xRMin; xR < xRMax; xR += dilationHeight) {
64 const xROffset = inputBatchOffset + xR * strides[1];
65 for (let xC = xCMin; xC < xCMax; xC += dilationWidth) {
66 const xCOffset = xROffset + xC * strides[2];
67 const pixel = xValues[xCOffset + d];
68 if ((poolType === 'max' && pixel > minMaxValue)) {
69 minMaxValue = pixel;
70 } else if (poolType === 'avg') {
71 avgValue += pixel;
72 count++;
73 }
74 }
75 if (isNaN(minMaxValue)) {
76 break;
77 }

Callers 5

maxPoolWithArgmaxImpl · Function · 0.90
avgPool · Function · 0.90
maxPool · Function · 0.90
pool.ts · File · 0.90
identityPoolTest · Function · 0.85

Calls 3

buffer · Function · 0.90
max · Method · 0.80
min · Method · 0.80

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…