MCPcopy Index your code
hub / github.com/tensorflow/tfjs / maxPool3DGrad

Function maxPool3DGrad

tfjs-backend-cpu/src/kernels/MaxPool3DGrad.ts:24–117  ·  view source on GitHub ↗
(args: {
  inputs: MaxPool3DGradInputs,
  backend: MathBackendCPU,
  attrs: MaxPool3DGradAttrs
})

Source from the content-addressed store, hash-verified

22import {maxPool3dPositions} from '../utils/pool_utils';
23
24export function maxPool3DGrad(args: {
25 inputs: MaxPool3DGradInputs,
26 backend: MathBackendCPU,
27 attrs: MaxPool3DGradAttrs
28}): TensorInfo {
29 const {inputs, backend, attrs} = args;
30 const {dy, input} = inputs;
31 const {filterSize, strides, pad, dimRoundingMode} = attrs;
32
33 assertNotComplex([dy, input], 'maxPool3DGrad');
34
35 const convInfo = backend_util.computePool3DInfo(
36 input.shape as [number, number, number, number, number], filterSize,
37 strides, 1 /* dilations */, pad, dimRoundingMode);
38
39 const inputBuf = backend.bufferSync(input);
40 const maxPosBuf = maxPool3dPositions(inputBuf, convInfo);
41 const strideDepth = convInfo.strideDepth;
42 const strideHeight = convInfo.strideHeight;
43 const strideWidth = convInfo.strideWidth;
44 const dilationDepth = convInfo.dilationDepth;
45 const dilationHeight = convInfo.dilationHeight;
46 const dilationWidth = convInfo.dilationWidth;
47 const effectiveFilterDepth = convInfo.effectiveFilterDepth;
48 const effectiveFilterHeight = convInfo.effectiveFilterHeight;
49 const effectiveFilterWidth = convInfo.effectiveFilterWidth;
50 const padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;
51 const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
52 const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
53 const dx = buffer(input.shape, 'float32');
54
55 const dyBuf = backend.bufferSync<Rank, 'float32'>(dy);
56
57 for (let batch = 0; batch < convInfo.batchSize; ++batch) {
58 for (let channel = 0; channel < convInfo.inChannels; ++channel) {
59 for (let dxDepth = 0; dxDepth < convInfo.inDepth; ++dxDepth) {
60 for (let dxRow = 0; dxRow < convInfo.inHeight; ++dxRow) {
61 for (let dxCol = 0; dxCol < convInfo.inWidth; ++dxCol) {
62 // Shader code begins
63 const dyDepthCorner = dxDepth - padFront;
64 const dyRowCorner = dxRow - padTop;
65 const dyColCorner = dxCol - padLeft;
66 let dotProd = 0;
67 for (let wDepth = 0; wDepth < effectiveFilterDepth;
68 wDepth += dilationDepth) {
69 const dyDepth = (dyDepthCorner + wDepth) / strideDepth;
70 if (dyDepth < 0 || dyDepth >= convInfo.outDepth ||
71 Math.floor(dyDepth) !== dyDepth) {
72 continue;
73 }
74 for (let wRow = 0; wRow < effectiveFilterHeight;
75 wRow += dilationHeight) {
76 const dyRow = (dyRowCorner + wRow) / strideHeight;
77 if (dyRow < 0 || dyRow >= convInfo.outHeight ||
78 Math.floor(dyRow) !== dyRow) {
79 continue;
80 }
81 for (let wCol = 0; wCol < effectiveFilterWidth;

Callers

nothing calls this directly

Calls 8

assertNotComplex (Function) · 0.90
maxPool3dPositions (Function) · 0.90
buffer (Function) · 0.90
floor (Method) · 0.80
bufferSync (Method) · 0.45
get (Method) · 0.45
set (Method) · 0.45
makeTensorInfo (Method) · 0.45

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…