Actual computation of the gradient of the loss with respect to X, W, and b
(self, dLdY, X, Z)
| 3519 | return dX[0] if len(X) == 1 else dX |
| 3520 | |
def _bwd(self, dLdY, X, Z):
    """
    Compute the gradients of the loss with respect to `X`, `W`, and `b`.

    The kernel stored under ``self.parameters["W"]`` is rotated by 180
    degrees in the plane of its first two axes before use, and the
    resulting weight gradient is rotated by 180 degrees again before it
    is returned. When ``self.stride`` exceeds 1, `X` is first passed
    through ``dilate`` with ``self.stride - 1`` and the remainder of the
    computation runs with unit stride.

    Parameters
    ----------
    dLdY : 4D array
        Gradient of the loss with respect to the layer output. Its last
        axis is taken as the number of output channels.
    X : array
        Layer input for which `dLdY` was computed.
        NOTE(review): presumably laid out as
        (n_ex, in_rows, in_cols, in_ch) -- confirm against the caller.
    Z : array
        Value handed to ``self.act_fn.grad``; the result is multiplied
        elementwise with `dLdY`.

    Returns
    -------
    dX : array
        Gradient with respect to `X`, keeping every ``self.stride``-th
        entry along axes 1 and 2 of the ``col2im`` result.
    dW : array of shape (fr, fc, in_ch, out_ch)
        Gradient with respect to the kernel, where ``(fr, fc)`` is
        ``self.kernel_shape``.
    dB : array of shape (1, 1, 1, out_ch)
        Gradient with respect to the bias.
    """
    kernel = np.rot90(self.parameters["W"], 2)

    stride = self.stride
    if self.stride > 1:
        X = dilate(X, stride - 1)
        stride = 1

    # this unpacking only insists that the kernel volume is 4D; every
    # value it would bind is superseded by the assignments below
    _, _, _, _ = kernel.shape
    kernel_dims = kernel.shape[:2]

    kshape, pad = self.kernel_shape, self.pad
    k_rows, k_cols = kshape
    _, _, _, n_out = dLdY.shape

    # first round of padding; `pad2D` also hands back the padding as a
    # 4-tuple (two row entries followed by two column entries)
    X_padded, pad = pad2D(X, pad, kernel_dims, stride)
    _, padded_rows, padded_cols, n_in = X_padded.shape
    pad_r1, pad_r2, pad_c1, pad_c2 = pad

    # spatial dimensions the second, "deconvolution" padding has to produce
    target_dims = (
        stride * (padded_rows - 1) - pad_r1 - pad_r2 + k_rows,
        stride * (padded_cols - 1) - pad_c1 - pad_c2 + k_cols,
    )

    # second round of padding on top of the first
    extra_pad = calc_pad_dims_2D(
        X_padded.shape, target_dims, kernel_dims, stride, 0
    )
    X_padded, _ = pad2D(X_padded, extra_pad, kernel_dims, stride)

    # push the upstream gradient through the activation, then pad it with
    # the first-round padding
    local_grad = dLdY * self.act_fn.grad(Z)
    local_grad, _ = pad2D(local_grad, pad, kernel_dims, stride)

    # flatten the gradient, the kernel, and the padded input into matrices
    grad_cols = local_grad.transpose(3, 1, 2, 0).reshape(n_out, -1)
    kernel_cols = kernel.transpose(3, 2, 0, 1).reshape(n_out, -1)
    X_cols, _ = im2col(X_padded, kernel.shape, 0, stride, 0)

    # bias and weight gradients via matrix products
    dB = grad_cols.sum(axis=1).reshape(1, 1, 1, -1)
    dW_flat = grad_cols @ X_cols.T
    dW = dW_flat.reshape(n_out, n_in, k_rows, k_cols).transpose(2, 3, 1, 0)
    dW = np.rot90(dW, 2)

    # input gradient in column form, folded back with both paddings combined
    dX_cols = kernel_cols.T @ grad_cols
    combined_pad = tuple(first + second for first, second in zip(pad, extra_pad))
    dX = col2im(dX_cols, X.shape, kernel.shape, combined_pad, stride, 0)
    dX = dX.transpose(0, 2, 3, 1)

    # subsample with the layer's configured stride
    step = self.stride
    dX = dX[:, ::step, ::step, :]

    return dX, dW, dB
| 3569 | |
| 3570 | |
| 3571 | ####################################################################### |