| 299 | |
| 300 | |
class TorchAddLayer(nn.Module):
    """Torch reference layer computing ``act_fn(X_1 + X_2 + ... + X_n)``.

    Intermediate tensors are kept on the instance (``self.Xs``, ``self.sum``,
    ``self.Y``) with ``retain_grad()`` so their gradients can be read back
    after a backward pass via :meth:`extract_grads`.
    """

    def __init__(self, act_fn, **kwargs):
        """Store the activation applied to the summed inputs.

        Parameters
        ----------
        act_fn : callable
            Called with the summed tensor; its result is the layer output.
        **kwargs
            Accepted and ignored.
        """
        super(TorchAddLayer, self).__init__()
        self.act_fn = act_fn

    def forward(self, Xs):
        """Sum the inputs elementwise and apply the activation.

        Parameters
        ----------
        Xs : sequence
            Inputs to add. Entries that are not ``torch.Tensor`` instances are
            converted with ``torchify``; tensors are used as-is and have
            ``retain_grad()`` called on them.

        Returns
        -------
        torch.Tensor
            ``self.act_fn`` applied to the sum of the inputs (also stored as
            ``self.Y``).
        """
        self.Xs = []

        first = Xs[0]
        if not isinstance(first, torch.Tensor):
            # Only non-tensor inputs are copied before conversion;
            # torch.Tensor has no ``.copy()`` method, so calling it
            # unconditionally broke tensor inputs.
            first = torchify(first.copy())

        # Clone so the in-place accumulation below never writes into the
        # first input itself.
        self.sum = first.clone()
        first.retain_grad()
        self.Xs.append(first)

        for x in Xs[1:]:
            if not isinstance(x, torch.Tensor):
                x = torchify(x)

            x.retain_grad()
            self.Xs.append(x)
            self.sum += x

        self.sum.retain_grad()
        self.Y = self.act_fn(self.sum)
        self.Y.retain_grad()
        return self.Y

    def extract_grads(self, X):
        """Run forward/backward with loss ``Y.sum()`` and collect the results.

        Parameters
        ----------
        X : sequence
            Inputs, passed straight to :meth:`forward`.

        Returns
        -------
        dict
            ``"Xs"`` (the argument ``X`` as given), ``"Sum"``, ``"Y"``,
            ``"dLdY"``, ``"dLdSum"`` and one ``"dLdX<k>"`` entry per input
            (1-based), the latter five kinds as NumPy arrays.
        """
        self.forward(X)
        self.loss = self.Y.sum()
        self.loss.backward()
        grads = {
            "Xs": X,
            "Sum": self.sum.detach().numpy(),
            "Y": self.Y.detach().numpy(),
            "dLdY": self.Y.grad.numpy(),
            "dLdSum": self.sum.grad.numpy(),
        }
        grads.update(
            {
                "dLdX{}".format(k): xi.grad.numpy()
                for k, xi in enumerate(self.Xs, start=1)
            }
        )
        return grads
| 345 | |
| 346 | |
| 347 | class TorchMultiplyLayer(nn.Module): |
no outgoing calls