(
xBuf: TensorBuffer<Rank, DataType>,
convInfo: backend_util.Conv3DInfo)
| 261 | } |
| 262 | |
| 263 | export function maxPool3dPositions( |
| 264 | xBuf: TensorBuffer<Rank, DataType>, |
| 265 | convInfo: backend_util.Conv3DInfo): TensorBuffer<Rank, DataType> { |
| 266 | const maxPositions = buffer(convInfo.outShape, 'int32'); |
| 267 | const strideDepth = convInfo.strideDepth; |
| 268 | const strideHeight = convInfo.strideHeight; |
| 269 | const strideWidth = convInfo.strideWidth; |
| 270 | const dilationDepth = convInfo.dilationDepth; |
| 271 | const dilationHeight = convInfo.dilationHeight; |
| 272 | const dilationWidth = convInfo.dilationWidth; |
| 273 | const effectiveFilterDepth = convInfo.effectiveFilterDepth; |
| 274 | const effectiveFilterHeight = convInfo.effectiveFilterHeight; |
| 275 | const effectiveFilterWidth = convInfo.effectiveFilterWidth; |
| 276 | const padFront = convInfo.padInfo.front; |
| 277 | const padTop = convInfo.padInfo.top; |
| 278 | const padLeft = convInfo.padInfo.left; |
| 279 | |
| 280 | for (let batch = 0; batch < convInfo.batchSize; ++batch) { |
| 281 | for (let channel = 0; channel < convInfo.inChannels; ++channel) { |
| 282 | for (let yDepth = 0; yDepth < convInfo.outDepth; ++yDepth) { |
| 283 | const xDepthCorner = yDepth * strideDepth - padFront; |
| 284 | let xDepthMin = xDepthCorner; |
| 285 | while (xDepthMin < 0) { |
| 286 | xDepthMin += dilationDepth; |
| 287 | } |
| 288 | const xDepthMax = |
| 289 | Math.min(convInfo.inDepth, effectiveFilterDepth + xDepthCorner); |
| 290 | for (let yRow = 0; yRow < convInfo.outHeight; ++yRow) { |
| 291 | const xRowCorner = yRow * strideHeight - padTop; |
| 292 | let xRowMin = xRowCorner; |
| 293 | while (xRowMin < 0) { |
| 294 | xRowMin += dilationHeight; |
| 295 | } |
| 296 | const xRowMax = |
| 297 | Math.min(convInfo.inHeight, effectiveFilterHeight + xRowCorner); |
| 298 | for (let yCol = 0; yCol < convInfo.outWidth; ++yCol) { |
| 299 | const xColCorner = yCol * strideWidth - padLeft; |
| 300 | let xColMin = xColCorner; |
| 301 | while (xColMin < 0) { |
| 302 | xColMin += dilationWidth; |
| 303 | } |
| 304 | const xColMax = |
| 305 | Math.min(convInfo.inWidth, effectiveFilterWidth + xColCorner); |
| 306 | |
| 307 | // Shader code begins |
| 308 | let maxValue = Number.NEGATIVE_INFINITY; |
| 309 | let maxPosition = -1; |
| 310 | |
| 311 | for (let xDepth = xDepthMin; xDepth < xDepthMax; |
| 312 | xDepth += dilationDepth) { |
| 313 | const wDepth = xDepth - xDepthCorner; |
| 314 | for (let xRow = xRowMin; xRow < xRowMax; xRow += dilationHeight) { |
| 315 | const wRow = xRow - xRowCorner; |
| 316 | for (let xCol = xColMin; xCol < xColMax; |
| 317 | xCol += dilationWidth) { |
| 318 | const wCol = xCol - xColCorner; |
| 319 | const pixel = xBuf.get(batch, xDepth, xRow, xCol, |
| 320 | channel) as number; |
no test coverage detected
searching dependent graphs…