MCPcopy Index your code
hub / github.com/ddbourgin/numpy-ml / TorchMultiplyLayer

Class TorchMultiplyLayer

numpy_ml/tests/nn_torch_models.py:347–390  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

345
346
class TorchMultiplyLayer(nn.Module):
    """PyTorch reference for an elementwise-product layer followed by an activation.

    ``forward`` multiplies all of its inputs together elementwise and applies
    ``act_fn`` to the product. ``extract_grads`` runs a forward pass, then a
    backward pass with loss ``Y.sum()``, and returns the intermediate values
    and gradients as NumPy arrays.
    """

    def __init__(self, act_fn, **kwargs):
        """
        Parameters
        ----------
        act_fn : callable
            Applied to the elementwise product of the inputs in ``forward``.
        **kwargs
            Accepted for call-site compatibility; not used.
        """
        super(TorchMultiplyLayer, self).__init__()
        self.act_fn = act_fn

    def forward(self, Xs):
        """Return ``act_fn(Xs[0] * Xs[1] * ... * Xs[-1])``.

        Parameters
        ----------
        Xs : sequence
            Inputs to multiply elementwise. Entries that are not already
            ``torch.Tensor`` objects are converted with ``torchify``. Tensor
            entries are used as-is and must have ``requires_grad=True``,
            because ``retain_grad`` raises on tensors that do not require grad.

        Returns
        -------
        torch.Tensor
            The activated product, also stored as ``self.Y``. The converted
            inputs are stored in ``self.Xs`` and the raw product in
            ``self.prod``.

        Raises
        ------
        IndexError
            If ``Xs`` is empty.
        """
        # Convert every input the same way. The old code called ``.copy()`` on
        # the first input before the tensor check, so a torch.Tensor there
        # raised AttributeError (tensors have ``copy_`` but no ``copy``).
        self.Xs = []
        for x in Xs:
            if not isinstance(x, torch.Tensor):
                x = torchify(x)
            x.retain_grad()
            self.Xs.append(x)

        # clone() makes ``prod`` a new graph node rather than an alias of the
        # first input. Out-of-place multiplication avoids mutating a tensor
        # that autograd may need to read during the backward pass.
        prod = self.Xs[0].clone()
        for x in self.Xs[1:]:
            prod = prod * x
        self.prod = prod

        self.prod.retain_grad()
        self.Y = self.act_fn(self.prod)
        self.Y.retain_grad()
        return self.Y

    def extract_grads(self, X):
        """Run forward + backward with loss ``Y.sum()`` and collect gradients.

        Parameters
        ----------
        X : sequence
            Inputs passed unchanged to :meth:`forward`.

        Returns
        -------
        dict
            ``"Xs"``: the inputs exactly as given. The following entries are
            NumPy arrays:

            * ``"Prod"``: the elementwise product.
            * ``"Y"``: the activated output.
            * ``"dLdY"`` and ``"dLdProd"``: gradients of the loss with respect
              to the output and the product.
            * ``"dLdX1"``, ``"dLdX2"``, ...: one gradient per input, numbered
              from 1.

        Notes
        -----
        Leaf-tensor ``.grad`` values accumulate across backward passes, so
        tensor inputs should be fresh for each call.
        """
        self.forward(X)
        self.loss = self.Y.sum()
        self.loss.backward()
        grads = {
            "Xs": X,
            "Prod": self.prod.detach().numpy(),
            "Y": self.Y.detach().numpy(),
            "dLdY": self.Y.grad.numpy(),
            "dLdProd": self.prod.grad.numpy(),
        }
        grads.update(
            {
                "dLdX{}".format(i): xi.grad.numpy()
                for i, xi in enumerate(self.Xs, start=1)
            }
        )
        return grads
391
392
393class TorchSkipConnectionIdentity(nn.Module):

Callers 1

test_MultiplyLayerFunction · 0.85

Calls

no outgoing calls

Tested by 1

test_MultiplyLayerFunction · 0.68