(self, Xs)
| 304 | self.act_fn = act_fn |
| 305 | |
def forward(self, Xs):
    """Sum a sequence of inputs elementwise and apply ``self.act_fn``.

    Each element of ``Xs`` is used as-is if it is already a ``torch.Tensor``;
    otherwise it is converted with ``torchify``.

    Side effects (all are read back by gradient-extraction code):
    - the converted inputs are stored on ``self.Xs``;
    - the running sum is stored on ``self.sum``;
    - the activated output is stored on ``self.Y``.

    ``retain_grad()`` is called on every one of these tensors.

    Parameters
    ----------
    Xs : sequence
        Non-empty, indexable sequence of inputs (arrays or tensors).
        NOTE(review): shapes are assumed to be compatible with the in-place
        ``self.sum += x`` below; confirm against callers.
        NOTE(review): tensors passed in directly presumably need
        ``requires_grad=True``, since ``retain_grad()`` is called on them.

    Returns
    -------
    torch.Tensor
        ``self.act_fn`` applied to the elementwise sum of the inputs.
        The same object is also stored as ``self.Y``.

    Raises
    ------
    IndexError
        If ``Xs`` is empty (raised by the ``Xs[0]`` lookup).
    """

    def as_tensor(x, copy=False):
        # Tensors pass through untouched: torch.Tensor has no numpy-style
        # ``.copy()``, so the copy may only happen on the conversion path.
        if isinstance(x, torch.Tensor):
            return x
        return torchify(x.copy() if copy else x)

    self.Xs = []

    # Only the first non-tensor input is copied before conversion, matching
    # the previous behaviour. NOTE(review): presumably this guards against
    # torchify sharing memory with the caller's array. Confirm, and decide
    # whether the remaining inputs need the same treatment.
    first = as_tensor(Xs[0], copy=True)
    # Seed the running sum with a clone so the in-place ``+=`` below never
    # mutates the stored first input.
    self.sum = first.clone()
    first.retain_grad()
    self.Xs.append(first)

    for x in Xs[1:]:
        x = as_tensor(x)
        x.retain_grad()
        self.Xs.append(x)
        self.sum += x

    self.sum.retain_grad()
    self.Y = self.act_fn(self.sum)
    self.Y.retain_grad()
    return self.Y
| 329 | |
| 330 | def extract_grads(self, X): |
| 331 | self.forward(X) |
no test coverage detected