全连接网络的前向传播,原理见上文 反向传播算法 部分。 参数说明: X:输入数组,为(n_samples, n_in),float型 retain_derived:是否保留中间变量,以便反向传播时再次使用,bool型
(self, X, retain_derived=True)
| 102 | self.is_initialized = True |
| 103 | |
| 104 | def forward(self, X, retain_derived=True): |
| 105 | """ |
| 106 | 全连接网络的前向传播,原理见上文 反向传播算法 部分。 |
| 107 | |
| 108 | 参数说明: |
| 109 | X:输入数组,为(n_samples, n_in),float型 |
| 110 | retain_derived:是否保留中间变量,以便反向传播时再次使用,bool型 |
| 111 | """ |
| 112 | if not self.is_initialized: # 如果参数未初始化,先初始化参数 |
| 113 | self.n_in = X.shape[1] |
| 114 | self._init_params() |
| 115 | |
| 116 | W = self.params["W"] |
| 117 | b = self.params["b"] |
| 118 | z = X @ W + b |
| 119 | a = self.acti_fn.forward(z) |
| 120 | |
| 121 | if retain_derived: |
| 122 | self.X.append(X) |
| 123 | |
| 124 | return a |
| 125 | |
| 126 | def backward(self, dLda, retain_grads=True): |
| 127 | """ |
nothing calls this directly
no test coverage detected