| 745 | |
| 746 | ############### Flatten ################## |
| 747 | class Flatten(LayerBase): |
| 748 | |
def __init__(self, keep_dim="first", optimizer=None):
    """
    Layer that flattens a multi-dimensional input into a 2-D array.

    Parameters
    ----------
    keep_dim : {'first', 'last', 'none'}, default 'first'
        How ``forward`` reshapes an input ``X``:
        ``'first'`` -> ``(X.shape[0], -1)``,
        ``'last'``  -> ``(-1, X.shape[-1])``,
        ``'none'``  -> ``(1, -1)``.
    optimizer : optional
        Optimization method; passed unchanged to ``LayerBase.__init__``.

    Raises
    ------
    ValueError
        If ``keep_dim`` is not one of ``'first'``, ``'last'`` or ``'none'``.
    """
    # Reject unknown modes up front: forward() treats every value other than
    # 'none' and 'first' as 'last', so a typo would otherwise go unnoticed.
    if keep_dim not in ("first", "last", "none"):
        raise ValueError(
            "keep_dim must be one of 'first', 'last' or 'none', "
            f"got {keep_dim!r}"
        )
    super().__init__(optimizer)

    self.keep_dim = keep_dim
    self._init_params()
| 763 | |
| 764 | def _init_params(self): |
| 765 | self.X = [] |
| 766 | self.gradients = {} |
| 767 | self.params = {} |
| 768 | self.derived_variables = {"in_dims": []} |
| 769 | |
def forward(self, X, retain_derived=True):
    """
    Forward pass: reshape ``X`` according to ``self.keep_dim``.

    Parameters
    ----------
    X : array
        Input array; must support ``.shape``, ``.reshape`` and ``.flatten``
        (e.g. a NumPy ndarray).
    retain_derived : bool, default True
        When True, append ``X.shape`` to ``self.derived_variables["in_dims"]``
        so ``backward`` can restore the original shape.

    Returns
    -------
    Y : array
        ``X`` reshaped to ``(1, -1)`` when ``keep_dim == 'none'``,
        ``(X.shape[0], -1)`` when ``keep_dim == 'first'``, and
        ``(-1, X.shape[-1])`` for any other value.
    """
    if retain_derived:
        self.derived_variables["in_dims"].append(X.shape)

    mode = self.keep_dim
    if mode == "none":
        # For NumPy arrays flatten() returns a copy, so the result never
        # aliases X in this mode.
        flat = X.flatten()
        return flat.reshape(1, -1)

    if mode == "first":
        target = (X.shape[0], -1)
    else:
        target = (-1, X.shape[-1])
    return X.reshape(*target)
| 784 | |
def backward(self, dLdy, retain_grads=True):
    """
    Backward pass: reshape each upstream gradient back to its input's shape.

    Parameters
    ----------
    dLdy : array or list of arrays
        Gradient(s) of the loss with respect to this layer's output(s). The
        i-th gradient is paired with the i-th shape in
        ``self.derived_variables["in_dims"]``.
    retain_grads : bool, default True
        Not used by this layer; kept for interface compatibility.

    Returns
    -------
    dX : array or list of arrays
        The gradient(s) reshaped to the recorded input shape(s). A single
        array is returned when exactly one gradient is supplied; otherwise a
        list is returned.

    Raises
    ------
    ValueError
        If more gradients are supplied than input shapes were recorded
        (e.g. ``forward`` ran with ``retain_derived=False``).
    """
    if not isinstance(dLdy, list):
        dLdy = [dLdy]
    in_dims = self.derived_variables["in_dims"]
    # zip() stops at the shorter sequence, which would silently drop any
    # gradient that has no recorded input shape; fail loudly instead.
    if len(dLdy) > len(in_dims):
        raise ValueError(
            f"backward received {len(dLdy)} gradient(s) but only "
            f"{len(in_dims)} input shape(s) were recorded by forward"
        )
    dX = [dy.reshape(*dims) for dy, dims in zip(dLdy, in_dims)]
    return dX[0] if len(dLdy) == 1 else dX
| 801 | |
| 802 | @property |
| 803 | def hyperparams(self): |
| 804 | return { |
no outgoing calls
no test coverage detected