| 154 | } |
| 155 | |
| 156 | export function pool3d( |
| 157 | xValues: TypedArray, xShape: number[], dtype: DataType, strides: number[], |
| 158 | convInfo: backend_util.Conv3DInfo, |
| 159 | poolType: 'max'|'avg'): TensorBuffer<Rank, DataType> { |
| 160 | const strideDepth = convInfo.strideDepth; |
| 161 | const strideHeight = convInfo.strideHeight; |
| 162 | const strideWidth = convInfo.strideWidth; |
| 163 | const dilationDepth = convInfo.dilationDepth; |
| 164 | const dilationHeight = convInfo.dilationHeight; |
| 165 | const dilationWidth = convInfo.dilationWidth; |
| 166 | const effectiveFilterDepth = convInfo.effectiveFilterDepth; |
| 167 | const effectiveFilterHeight = convInfo.effectiveFilterHeight; |
| 168 | const effectiveFilterWidth = convInfo.effectiveFilterWidth; |
| 169 | const padFront = convInfo.padInfo.front; |
| 170 | const padTop = convInfo.padInfo.top; |
| 171 | const padLeft = convInfo.padInfo.left; |
| 172 | |
| 173 | const initialValue = |
| 174 | (poolType === 'max' ? Number.NEGATIVE_INFINITY : |
| 175 | Number.POSITIVE_INFINITY); |
| 176 | |
| 177 | const output = buffer(convInfo.outShape, dtype); |
| 178 | const outputVals = output.values; |
| 179 | |
| 180 | const outputBatchStrides = convInfo.outShape[1] * convInfo.outShape[2] * |
| 181 | convInfo.outShape[3] * convInfo.outShape[4]; |
| 182 | const outputDepthStrides = |
| 183 | convInfo.outShape[2] * convInfo.outShape[3] * convInfo.outShape[4]; |
| 184 | const outputRowStrides = convInfo.outShape[3] * convInfo.outShape[4]; |
| 185 | const outputColStrides = convInfo.outShape[4]; |
| 186 | |
| 187 | for (let batch = 0; batch < convInfo.batchSize; ++batch) { |
| 188 | const outputBatchOffset = batch * outputBatchStrides; |
| 189 | const inputBatchOffset = batch * strides[0]; |
| 190 | for (let channel = 0; channel < convInfo.inChannels; ++channel) { |
| 191 | for (let yDepth = 0; yDepth < convInfo.outDepth; ++yDepth) { |
| 192 | const xDepthCorner = yDepth * strideDepth - padFront; |
| 193 | let xDepthMin = xDepthCorner; |
| 194 | while (xDepthMin < 0) { |
| 195 | xDepthMin += dilationDepth; |
| 196 | } |
| 197 | const xDepthMax = |
| 198 | Math.min(convInfo.inDepth, effectiveFilterDepth + xDepthCorner); |
| 199 | const outputDepthOffset = |
| 200 | outputBatchOffset + yDepth * outputDepthStrides; |
| 201 | for (let yRow = 0; yRow < convInfo.outHeight; ++yRow) { |
| 202 | const xRowCorner = yRow * strideHeight - padTop; |
| 203 | let xRowMin = xRowCorner; |
| 204 | while (xRowMin < 0) { |
| 205 | xRowMin += dilationHeight; |
| 206 | } |
| 207 | const xRowMax = |
| 208 | Math.min(convInfo.inHeight, effectiveFilterHeight + xRowCorner); |
| 209 | const outputRowOffset = outputDepthOffset + yRow * outputRowStrides; |
| 210 | for (let yCol = 0; yCol < convInfo.outWidth; ++yCol) { |
| 211 | const xColCorner = yCol * strideWidth - padLeft; |
| 212 | let xColMin = xColCorner; |
| 213 | while (xColMin < 0) { |