MCPcopy
hub / github.com/tensorflow/tfjs / avgPool3DGrad

Function avgPool3DGrad

tfjs-backend-cpu/src/kernels/AvgPool3DGrad.ts:23–106  ·  view source on GitHub ↗
(args: {
  inputs: AvgPool3DGradInputs,
  backend: MathBackendCPU,
  attrs: AvgPool3DGradAttrs
})

Source from the content-addressed store, hash-verified

21import {assertNotComplex} from '../cpu_util';
22
23export function avgPool3DGrad(args: {
24 inputs: AvgPool3DGradInputs,
25 backend: MathBackendCPU,
26 attrs: AvgPool3DGradAttrs
27}): TensorInfo {
28 const {inputs, backend, attrs} = args;
29 const {dy, input} = inputs;
30 const {filterSize, strides, pad, dimRoundingMode} = attrs;
31
32 assertNotComplex([dy, input], 'avgPool3DGrad');
33
34 const convInfo = backend_util.computePool3DInfo(
35 input.shape as [number, number, number, number, number], filterSize,
36 strides, 1 /* dilations */, pad, dimRoundingMode);
37
38 const strideDepth = convInfo.strideDepth;
39 const strideHeight = convInfo.strideHeight;
40 const strideWidth = convInfo.strideWidth;
41 const filterDepth = convInfo.filterDepth;
42 const filterHeight = convInfo.filterHeight;
43 const filterWidth = convInfo.filterWidth;
44 const dilationDepth = convInfo.dilationDepth;
45 const dilationHeight = convInfo.dilationHeight;
46 const dilationWidth = convInfo.dilationWidth;
47 const effectiveFilterDepth = convInfo.effectiveFilterDepth;
48 const effectiveFilterHeight = convInfo.effectiveFilterHeight;
49 const effectiveFilterWidth = convInfo.effectiveFilterWidth;
50 const padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front;
51 const padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
52 const padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
53 const dx = buffer(input.shape, 'float32');
54
55 const avgMultiplier = 1 / (filterDepth * filterHeight * filterWidth);
56
57 const dyBuf = backend.bufferSync<Rank, 'float32'>(dy);
58
59 for (let batch = 0; batch < convInfo.batchSize; ++batch) {
60 for (let channel = 0; channel < convInfo.inChannels; ++channel) {
61 for (let dxDepth = 0; dxDepth < convInfo.inDepth; ++dxDepth) {
62 for (let dxRow = 0; dxRow < convInfo.inHeight; ++dxRow) {
63 for (let dxCol = 0; dxCol < convInfo.inWidth; ++dxCol) {
64 // Shader code begins.
65 const dyDepthCorner = dxDepth - padFront;
66 const dyRowCorner = dxRow - padTop;
67 const dyColCorner = dxCol - padLeft;
68 let dotProd = 0;
69 for (let wDepth = 0; wDepth < effectiveFilterDepth;
70 wDepth += dilationDepth) {
71 const dyDepth = (dyDepthCorner + wDepth) / strideDepth;
72 if (dyDepth < 0 || dyDepth >= convInfo.outDepth ||
73 Math.floor(dyDepth) !== dyDepth) {
74 continue;
75 }
76 for (let wRow = 0; wRow < effectiveFilterHeight;
77 wRow += dilationHeight) {
78 const dyRow = (dyRowCorner + wRow) / strideHeight;
79 if (dyRow < 0 || dyRow >= convInfo.outHeight ||
80 Math.floor(dyRow) !== dyRow) {

Callers

nothing calls this directly

Calls 7

assertNotComplex · Function · 0.90
buffer · Function · 0.90
floor · Method · 0.80
bufferSync · Method · 0.45
get · Method · 0.45
set · Method · 0.45
makeTensorInfo · Method · 0.45

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…