| 345 | |
| 346 | |
class TorchMultiplyLayer(nn.Module):
    """Multiply a sequence of inputs elementwise, then apply an activation.

    The intermediate tensors of the most recent forward pass are kept on the
    instance (``Xs``, ``prod``, ``Y``) with ``retain_grad`` enabled so that
    :meth:`extract_grads` can read their gradients after a backward pass.
    """

    def __init__(self, act_fn, **kwargs):
        """
        Parameters
        ----------
        act_fn : callable
            Applied to the product of the inputs in :meth:`forward`.
        **kwargs
            Accepted and ignored.
        """
        super(TorchMultiplyLayer, self).__init__()
        self.act_fn = act_fn

    def forward(self, Xs):
        """Compute ``act_fn(Xs[0] * Xs[1] * ... * Xs[-1])``.

        Parameters
        ----------
        Xs : sequence
            Inputs to multiply. Entries that are not already
            ``torch.Tensor`` instances are converted with ``torchify``.
            Tensor entries are used as-is and must support ``retain_grad``.

        Returns
        -------
        torch.Tensor
            The activation output, also stored as ``self.Y``.
        """
        first = Xs[0]
        if not isinstance(first, torch.Tensor):
            # Only non-tensor inputs are copied before conversion:
            # torch.Tensor has no ``.copy()`` method, so copying ahead of the
            # type check made tensor inputs fail with AttributeError.
            first = torchify(first.copy())
        first.retain_grad()
        self.Xs = [first]

        # Accumulate into a clone so the in-place multiplications below never
        # modify the first input itself.
        self.prod = first.clone()

        for x in Xs[1:]:
            if not isinstance(x, torch.Tensor):
                x = torchify(x)
            x.retain_grad()
            self.Xs.append(x)
            self.prod *= x

        # prod and Y are non-leaf tensors; retain_grad keeps their gradients
        # available after backward().
        self.prod.retain_grad()
        self.Y = self.act_fn(self.prod)
        self.Y.retain_grad()
        return self.Y

    def extract_grads(self, X):
        """Run forward and backward passes and collect values and gradients.

        The loss is the sum of all entries of the layer output.

        Parameters
        ----------
        X : sequence
            Inputs, passed unchanged to :meth:`forward`.

        Returns
        -------
        dict
            ``"Xs"`` (the argument ``X`` as given), ``"Prod"``, ``"Y"``,
            ``"dLdY"``, ``"dLdProd"`` and one ``"dLdX<k>"`` entry per input
            (1-indexed), each as a NumPy array except ``"Xs"``.
        """
        self.forward(X)
        self.loss = self.Y.sum()
        self.loss.backward()
        grads = {
            "Xs": X,
            "Prod": self.prod.detach().numpy(),
            "Y": self.Y.detach().numpy(),
            "dLdY": self.Y.grad.numpy(),
            "dLdProd": self.prod.grad.numpy(),
        }
        grads.update(
            {
                "dLdX{}".format(k): xi.grad.numpy()
                for k, xi in enumerate(self.Xs, start=1)
            }
        )
        return grads
| 391 | |
| 392 | |
| 393 | class TorchSkipConnectionIdentity(nn.Module): |
no outgoing calls