Actual computation of the gradient of the loss with respect to X, W, and b
(self, dLdy, X, Z)
| 3091 | return dX[0] if len(X) == 1 else dX |
| 3092 | |
def _bwd(self, dLdy, X, Z):
    """
    Compute the gradient of the loss with respect to X, W, and b.

    Parameters
    ----------
    dLdy : array with 4 axes, unpacked as (n_ex, out_rows, out_cols, out_ch)
        Gradient of the loss with respect to the layer output.
    X : array
        Input volume; it is columnized with `im2col` and its shape is
        passed to `col2im` when rebuilding `dX`.
    Z : array
        Point at which `self.act_fn.grad` is evaluated; the result is
        multiplied elementwise with `dLdy`. Presumably the layer's
        pre-activation output -- confirm against the caller.

    Returns
    -------
    dX : array
        Gradient of the loss with respect to `X`, rebuilt from its
        columnized form by `col2im` and then transposed with axes
        (0, 2, 3, 1).
    dW : array of shape (fr, fc, in_ch, out_ch)
        Gradient of the loss with respect to the kernel weights `W`.
    dB : array of shape (1, 1, 1, out_ch)
        Gradient of the loss with respect to the bias.
    """
    weights = self.parameters["W"]
    dilation = self.dilation

    # Only the channel counts are needed from these shapes; the 4-way
    # unpacking additionally requires both arrays to have exactly 4 axes.
    _, _, in_ch, _ = weights.shape
    _, _, _, out_ch = dLdy.shape
    (fr, fc), stride, pad = self.kernel_shape, self.stride, self.pad

    # Chain rule through the activation, then flatten everything so that
    # the output channel indexes the rows of the columnized gradient.
    grad_z = dLdy * self.act_fn.grad(Z)
    grad_z_cols = grad_z.transpose(3, 1, 2, 0).reshape(out_ch, -1)
    w_cols = weights.transpose(3, 2, 0, 1).reshape(out_ch, -1).T

    # im2col hands back a second value that replaces `pad`; that value
    # (not the raw self.pad) is what col2im receives further down.
    x_cols, pad = im2col(X, weights.shape, pad, stride, dilation)

    # Bias gradient: sum over every non-channel position.
    dB = grad_z_cols.sum(axis=1).reshape(1, 1, 1, -1)

    # Weight gradient: one matrix product, then restore W's axis order.
    dW = grad_z_cols @ x_cols.T
    dW = dW.reshape(out_ch, in_ch, fr, fc).transpose(2, 3, 1, 0)

    # Input gradient: still columnized here; fold it back into the same
    # format as the input volume.
    dX_cols = w_cols @ grad_z_cols
    dX = col2im(dX_cols, X.shape, weights.shape, pad, stride, dilation)
    dX = dX.transpose(0, 2, 3, 1)

    return dX, dW, dB
| 3117 | |
| 3118 | def _backward_naive(self, dLdy, retain_grads=True): |
| 3119 | """ |