(self, X)
| 328 | return self.Y |
| 329 | |
def extract_grads(self, X):
    """
    Run a forward/backward pass on `X` and gather values and gradients.

    The scalar loss is the sum of every entry of ``self.Y``, so the
    gradient of the loss with respect to ``self.Y`` is all ones.

    Parameters
    ----------
    X : object
        Input handed unchanged to ``self.forward`` and echoed back under
        the ``"Xs"`` key.

    Returns
    -------
    dict
        ``"Xs"`` (the input as given), ``"Sum"`` and ``"Y"`` (detached
        NumPy copies of ``self.sum`` / ``self.Y``), ``"dLdY"`` and
        ``"dLdSum"`` (their gradients), plus one ``"dLdX<k>"`` entry per
        element of ``self.Xs``, numbered from 1.

    Notes
    -----
    Assumes ``self.forward`` populates ``self.Y``, ``self.sum`` and
    ``self.Xs`` and that each of those tensors retains its gradient;
    neither is visible from this method -- TODO confirm against `forward`.
    """
    self.forward(X)

    # Reduce the output to a scalar so `backward` needs no explicit seed.
    self.loss = self.Y.sum()
    self.loss.backward()

    collected = dict(
        Xs=X,
        Sum=self.sum.detach().numpy(),
        Y=self.Y.detach().numpy(),
        dLdY=self.Y.grad.numpy(),
        dLdSum=self.sum.grad.numpy(),
    )

    # Per-input gradients are keyed with 1-based indices: dLdX1, dLdX2, ...
    for position, operand in enumerate(self.Xs, start=1):
        collected["dLdX{}".format(position)] = operand.grad.numpy()

    return collected
| 345 | |
| 346 | |
| 347 | class TorchMultiplyLayer(nn.Module): |