(args: {
inputs: MaxPool3DGradInputs,
backend: MathBackendCPU,
attrs: MaxPool3DGradAttrs
})
| 22 | import {maxPool3dPositions} from '../utils/pool_utils'; |
| 23 | |
| 24 | export function maxPool3DGrad(args: { |
| 25 | inputs: MaxPool3DGradInputs, |
| 26 | backend: MathBackendCPU, |
| 27 | attrs: MaxPool3DGradAttrs |
| 28 | }): TensorInfo { |
| 29 | const {inputs, backend, attrs} = args; |
| 30 | const {dy, input} = inputs; |
| 31 | const {filterSize, strides, pad, dimRoundingMode} = attrs; |
| 32 | |
| 33 | assertNotComplex([dy, input], 'maxPool3DGrad'); |
| 34 | |
| 35 | const convInfo = backend_util.computePool3DInfo( |
| 36 | input.shape as [number, number, number, number, number], filterSize, |
| 37 | strides, 1 /* dilations */, pad, dimRoundingMode); |
| 38 | |
| 39 | const inputBuf = backend.bufferSync(input); |
| 40 | const maxPosBuf = maxPool3dPositions(inputBuf, convInfo); |
| 41 | const strideDepth = convInfo.strideDepth; |
| 42 | const strideHeight = convInfo.strideHeight; |
| 43 | const strideWidth = convInfo.strideWidth; |
| 44 | const dilationDepth = convInfo.dilationDepth; |
| 45 | const dilationHeight = convInfo.dilationHeight; |
| 46 | const dilationWidth = convInfo.dilationWidth; |
| 47 | const effectiveFilterDepth = convInfo.effectiveFilterDepth; |
| 48 | const effectiveFilterHeight = convInfo.effectiveFilterHeight; |
| 49 | const effectiveFilterWidth = convInfo.effectiveFilterWidth; |
| 50 | const padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front; |
| 51 | const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left; |
| 52 | const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top; |
| 53 | const dx = buffer(input.shape, 'float32'); |
| 54 | |
| 55 | const dyBuf = backend.bufferSync<Rank, 'float32'>(dy); |
| 56 | |
| 57 | for (let batch = 0; batch < convInfo.batchSize; ++batch) { |
| 58 | for (let channel = 0; channel < convInfo.inChannels; ++channel) { |
| 59 | for (let dxDepth = 0; dxDepth < convInfo.inDepth; ++dxDepth) { |
| 60 | for (let dxRow = 0; dxRow < convInfo.inHeight; ++dxRow) { |
| 61 | for (let dxCol = 0; dxCol < convInfo.inWidth; ++dxCol) { |
| 62 | // Shader code begins |
| 63 | const dyDepthCorner = dxDepth - padFront; |
| 64 | const dyRowCorner = dxRow - padTop; |
| 65 | const dyColCorner = dxCol - padLeft; |
| 66 | let dotProd = 0; |
| 67 | for (let wDepth = 0; wDepth < effectiveFilterDepth; |
| 68 | wDepth += dilationDepth) { |
| 69 | const dyDepth = (dyDepthCorner + wDepth) / strideDepth; |
| 70 | if (dyDepth < 0 || dyDepth >= convInfo.outDepth || |
| 71 | Math.floor(dyDepth) !== dyDepth) { |
| 72 | continue; |
| 73 | } |
| 74 | for (let wRow = 0; wRow < effectiveFilterHeight; |
| 75 | wRow += dilationHeight) { |
| 76 | const dyRow = (dyRowCorner + wRow) / strideHeight; |
| 77 | if (dyRow < 0 || dyRow >= convInfo.outHeight || |
| 78 | Math.floor(dyRow) !== dyRow) { |
| 79 | continue; |
| 80 | } |
| 81 | for (let wCol = 0; wCol < effectiveFilterWidth; |
nothing calls this directly
no test coverage detected
searching dependent graphs…