Actual computation of the gradient of the loss with respect to X, W, and b
(self, dLdy, X, Z)
| 2797 | return dX[0] if len(X) == 1 else dX |
| 2798 | |
def _bwd(self, dLdy, X, Z):
    """
    Compute the gradient of the loss with respect to X, W, and b.

    The 1D problem is lifted to 2D by inserting a singleton row axis into
    `X`, `W`, and the pre-activation gradient, so that the 2D
    ``im2col`` / ``col2im`` helpers can be reused.

    Parameters
    ----------
    dLdy : array with three axes, unpacked here as `(n_ex, l_out, out_ch)`
        Gradient of the loss with respect to the layer output.
    X : array
        Layer input (presumably `(n_ex, l_in, in_ch)` -- TODO confirm
        against the caller).
    Z : array
        Values handed to ``self.act_fn.grad``; the result is multiplied
        elementwise with `dLdy`.

    Returns
    -------
    dX : array
        Output of ``col2im`` with the singleton row axis squeezed out
        (presumably the same shape as `X` -- TODO confirm).
    dW : array of shape `(kernel_width, in_ch, out_ch)`
        Gradient with respect to ``self.parameters["W"]``.
    dB : array of shape `(1, 1, out_ch)`
        Gradient with respect to the bias.
    """
    kernel = self.parameters["W"]

    # Insert a singleton "row" axis so the 2D im2col/col2im helpers apply.
    x_4d = np.expand_dims(X, axis=1)
    w_4d = np.expand_dims(kernel, axis=0)
    dz_4d = np.expand_dims(dLdy * self.act_fn.grad(Z), axis=1)

    dilation = self.dilation
    # The unpacking doubles as a rank check on the kernel and on dLdy.
    _, _, in_ch, _ = w_4d.shape
    _, _, out_ch = dLdy.shape
    k_rows, k_cols, stride = 1, self.kernel_width, self.stride

    # pad1D (rather than pad2D) is used so that self.pad = 'causal' is
    # handled correctly; the result is embedded into a 4-tuple whose first
    # two entries are zero.
    _, pad_1d = pad1D(X, self.pad, self.kernel_width, stride, dilation)
    pad_2d = (0, 0, pad_1d[0], pad_1d[1])

    # Flatten the upstream gradient, the kernel, and the input into 2D
    # matrices with out_ch as the leading axis where applicable.
    dz_cols = dz_4d.transpose(3, 1, 2, 0).reshape(out_ch, -1)
    w_cols = w_4d.transpose(3, 2, 0, 1).reshape(out_ch, -1).T
    x_cols, _ = im2col(x_4d, w_4d.shape, pad_2d, stride, dilation)

    # Bias gradient: sum the columnized gradient over every non-channel entry.
    bias_grad = dz_cols.sum(axis=1).reshape(1, 1, -1)

    # Kernel gradient: one matrix product, then restore the
    # (rows, cols, in_ch, out_ch) layout.
    kernel_grad = dz_cols @ x_cols.T
    kernel_grad = kernel_grad.reshape(out_ch, in_ch, k_rows, k_cols)
    kernel_grad = kernel_grad.transpose(2, 3, 1, 0)

    # Input gradient: fold the columnized result back into an image volume.
    input_grad = col2im(
        w_cols @ dz_cols, x_4d.shape, w_4d.shape, pad_2d, stride, dilation
    )
    input_grad = input_grad.transpose(0, 2, 3, 1)

    # Drop the singleton row axes that were added above.
    return (
        np.squeeze(input_grad, axis=1),
        np.squeeze(kernel_grad, axis=0),
        bias_grad,
    )
| 2832 | |
| 2833 | def _backward_naive(self, dLdy, retain_grads=True): |
| 2834 | """ |