MCPcopy
hub / github.com/ddbourgin/numpy-ml / _bwd

Method _bwd

numpy_ml/neural_nets/layers/layers.py:3521–3568  ·  view source on GitHub ↗

Computes the gradients of the loss with respect to X, W, and b.

(self, dLdY, X, Z)

Source from the content-addressed store, hash-verified

3519 return dX[0] if len(X) == 1 else dX
3520
def _bwd(self, dLdY, X, Z):
    """
    Compute the gradients of the loss with respect to `X`, `W`, and `b`.

    Parameters
    ----------
    dLdY : :py:class:`ndarray <numpy.ndarray>`
        Gradient of the loss with respect to the layer output. It must be
        4-dimensional; its last axis is used as the number of output
        channels.
    X : :py:class:`ndarray <numpy.ndarray>`
        Layer input. When ``self.stride > 1`` it is dilated by
        ``self.stride - 1`` before any further processing.
    Z : :py:class:`ndarray <numpy.ndarray>`
        Value passed to ``self.act_fn.grad`` (presumably the layer's
        pre-activation output -- confirm against the caller).

    Returns
    -------
    dX : :py:class:`ndarray <numpy.ndarray>`
        Gradient of the loss with respect to `X`, subsampled by
        ``self.stride`` along axes 1 and 2.
    dW : :py:class:`ndarray <numpy.ndarray>`
        Gradient of the loss with respect to the kernel ``W``.
    dB : :py:class:`ndarray <numpy.ndarray>`
        Gradient of the loss with respect to the bias, reshaped to
        `(1, 1, 1, -1)`.
    """
    # the computation runs against the kernel rotated by 180 degrees in
    # its first two axes
    kernel = np.rot90(self.parameters["W"], 2)

    # a stride larger than 1 is handled by dilating X and then running the
    # remainder of the computation with a unit step
    stride = self.stride
    step = stride
    if stride > 1:
        X = dilate(X, stride - 1)
        step = 1

    # NOTE: all four values unpacked from the kernel shape are overwritten
    # before they are used; the unpacking itself only requires the kernel
    # to be 4-dimensional
    k_rows, k_cols, n_in, n_out = kernel.shape
    k_rows, k_cols = self.kernel_shape
    pad = self.pad
    _, _, _, n_out = dLdY.shape
    k_dims = kernel.shape[:2]

    # first round of padding: the layer's own padding specification
    X_pad, pad = pad2D(X, pad, k_dims, step)
    _, in_rows, in_cols, n_in = X_pad.shape
    pad_r1, pad_r2, pad_c1, pad_c2 = pad

    # target spatial dims used to derive the additional "deconvolution"
    # padding
    target_rows = step * (in_rows - 1) - pad_r1 - pad_r2 + k_rows
    target_cols = step * (in_cols - 1) - pad_c1 - pad_c2 + k_cols
    target_dims = (target_rows, target_cols)

    # second round of padding, applied on top of the first
    extra_pad = calc_pad_dims_2D(X_pad.shape, target_dims, k_dims, step, 0)
    X_pad, _ = pad2D(X_pad, extra_pad, k_dims, step)

    # backpropagate through the activation, then pad the result with the
    # (resolved) first-round padding
    dLdZ = dLdY * self.act_fn.grad(Z)
    dLdZ, _ = pad2D(dLdZ, pad, k_dims, step)

    # flatten everything into 2D matrices whose leading axis (for dLdZ and
    # the kernel) indexes output channels
    dLdZ_col = dLdZ.transpose(3, 1, 2, 0).reshape(n_out, -1)
    kernel_col = kernel.transpose(3, 2, 0, 1).reshape(n_out, -1)
    X_col, _ = im2col(X_pad, kernel.shape, 0, step, 0)

    # bias gradient: sum over every non-channel position
    dB = dLdZ_col.sum(axis=1).reshape(1, 1, 1, -1)

    # kernel gradient: matrix product, restored to the kernel's axis
    # layout, then rotated by 180 degrees as the kernel was on entry
    dW_flat = dLdZ_col @ X_col.T
    dW = dW_flat.reshape(n_out, n_in, k_rows, k_cols).transpose(2, 3, 1, 0)
    dW = np.rot90(dW, 2)

    # input gradient: still columnized at this point
    dX_col = kernel_col.T @ dLdZ_col

    # fold the columns back into the shape of the (possibly dilated) input
    # using the sum of both padding rounds
    total_pad = tuple(first + second for first, second in zip(pad, extra_pad))
    dX = col2im(dX_col, X.shape, kernel.shape, total_pad, step, 0)
    dX = dX.transpose(0, 2, 3, 1)

    # keep every `stride`-th row and column (a no-op when the stride is 1)
    dX = dX[:, ::stride, ::stride, :]

    return dX, dW, dB
3569
3570
3571#######################################################################

Callers 1

backwardMethod · 0.95

Calls 6

dilateFunction · 0.85
pad2DFunction · 0.85
calc_pad_dims_2DFunction · 0.85
im2colFunction · 0.85
col2imFunction · 0.85
gradMethod · 0.45

Tested by

no test coverage detected