池化层的反向传播,原理见上文。 参数说明: dLdy:关于损失的梯度,为 (n_samples, out_rows, out_cols, out_ch) retain_grads:是否计算中间变量的参数梯度,bool型 输出说明: dXs:即dX,当前卷积层对输入关于损失的梯度,为 (n_samples, in_rows, in_cols, in_ch)
(self, dLdy, retain_grads=True)
| 676 | return Y |
| 677 | |
| 678 | def backward(self, dLdy, retain_grads=True): |
| 679 | """ |
| 680 | 池化层的反向传播,原理见上文。 |
| 681 | |
| 682 | 参数说明: |
| 683 | dLdy:关于损失的梯度,为 (n_samples, out_rows, out_cols, out_ch) |
| 684 | retain_grads:是否计算中间变量的参数梯度,bool型 |
| 685 | |
| 686 | 输出说明: |
| 687 | dXs:即dX,当前卷积层对输入关于损失的梯度,为 (n_samples, in_rows, in_cols, in_ch) |
| 688 | """ |
| 689 | if not isinstance(dLdy, list): |
| 690 | dLdy = [dLdy] |
| 691 | |
| 692 | Xs = self.X |
| 693 | out_rows = self.derived_variables["out_rows"] |
| 694 | out_cols = self.derived_variables["out_cols"] |
| 695 | |
| 696 | (fr, fc), s, p = self.kernel_shape, self.stride, self.pad |
| 697 | |
| 698 | dXs = [] |
| 699 | for X, dy, out_row, out_col in zip(Xs, dLdy, out_rows, out_cols): |
| 700 | n_samp, in_rows, in_cols, nc_in = X.shape |
| 701 | X_pad, (pr1, pr2, pc1, pc2) = pad2D(X, p, self.kernel_shape, s) |
| 702 | |
| 703 | dX = np.zeros_like(X_pad) |
| 704 | for m in range(n_samp): |
| 705 | for i in range(out_row): |
| 706 | for j in range(out_col): |
| 707 | for c in range(self.out_ch): |
| 708 | i0, i1 = i * s, (i * s) + fr |
| 709 | j0, j1 = j * s, (j * s) + fc |
| 710 | |
| 711 | if self.mode == "max": |
| 712 | xi = X[m, i0:i1, j0:j1, c] |
| 713 | mask = np.zeros_like(xi).astype(bool) |
| 714 | x, y = np.argwhere(xi == np.max(xi))[0] |
| 715 | mask[x, y] = True |
| 716 | dX[m, i0:i1, j0:j1, c] += mask * dy[m, i, j, c] |
| 717 | |
| 718 | elif self.mode == "average": |
| 719 | frame = np.ones((fr, fc)) * dy[m, i, j, c] |
| 720 | dX[m, i0:i1, j0:j1, c] += frame / np.prod((fr, fc)) |
| 721 | |
| 722 | pr2 = None if pr2 == 0 else -pr2 |
| 723 | pc2 = None if pc2 == 0 else -pc2 |
| 724 | dXs.append(dX[:, pr1:pr2, pc1:pc2, :]) |
| 725 | |
| 726 | return dXs[0] if len(Xs) == 1 else dXs |
| 727 | |
| 728 | @property |
| 729 | def hyperparams(self): |