MCPcopy Index your code
hub / github.com/MingchaoZhu/DeepLearning / Flatten

Class Flatten

code/chapter9.py:747–811  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

745
746############### Flatten ##################
747class Flatten(LayerBase):
748
749 def __init__(self, keep_dim="first", optimizer=None):
750 """
751 将多维输入展开
752
753 参数说明:
754 keep_dim:展开形状,str (default : 'first')
755 对于输入 X,keep_dim可选 'first'->将 X 重构为(X.shape[0], -1),
756 'last'->将 X 重构为(-1, X.shape[0]),'none'->将 X 重构为(1,-1)
757 optimizer:优化方法
758 """
759 super().__init__(optimizer)
760
761 self.keep_dim = keep_dim
762 self._init_params()
763
764 def _init_params(self):
765 self.X = []
766 self.gradients = {}
767 self.params = {}
768 self.derived_variables = {"in_dims": []}
769
770 def forward(self, X, retain_derived=True):
771 """
772 前向传播
773
774 参数说明:
775 X:输入数组
776 retain_derived:是否保留中间变量,以便反向传播时再次使用,bool型
777 """
778 if retain_derived:
779 self.derived_variables["in_dims"].append(X.shape)
780 if self.keep_dim == "none":
781 return X.flatten().reshape(1, -1)
782 rs = (X.shape[0], -1) if self.keep_dim == "first" else (-1, X.shape[-1])
783 return X.reshape(*rs)
784
785 def backward(self, dLdy, retain_grads=True):
786 """
787 反向传播
788
789 参数说明:
790 dLdy:关于损失的梯度
791 retain_grads:是否计算中间变量的参数梯度,bool型
792
793 输出说明:
794 dX:将对输入的梯度进行重构为原始输入的形状
795 """
796 if not isinstance(dLdy, list):
797 dLdy = [dLdy]
798 in_dims = self.derived_variables["in_dims"]
799 dX = [dy.reshape(*dims) for dy, dims in zip(dLdy, in_dims)]
800 return dX[0] if len(dLdy) == 1 else dX
801
802 @property
803 def hyperparams(self):
804 return {

Callers 2

_set_paramsMethod · 0.85
_set_paramsMethod · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected