github.com/ddbourgin/numpy-ml

Method forward

numpy_ml/tests/nn_torch_models.py:87–101
(self, Y_real, Y_fake, gradInterp)
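For reference, the three quantities `forward` stores can be written as below. Here $g_i$ is row $i$ of `gradInterp` (the `dim=1` reductions treat it as a 2-D array with $n$ rows, which is an assumption about its shape), $\lambda$ is `self.lambda_`, and $\operatorname{mean}(\cdot)$ averages over all entries. Judging by the class name `TorchWGANGPLoss`, `Y_real` and `Y_fake` are the critic's scores on real and generated samples; the excerpt itself does not state this.

$$
\begin{aligned}
\text{gpenalty} &= \frac{\lambda}{n}\sum_{i=1}^{n}\bigl(\lVert g_i \rVert_2 - 1\bigr)^2, \qquad \lVert g_i \rVert_2 = \sqrt{\textstyle\sum_j g_{ij}^2} \\
\text{C\_loss} &= \operatorname{mean}(Y_{\text{fake}}) - \operatorname{mean}(Y_{\text{real}}) + \text{gpenalty} \\
\text{G\_loss} &= -\operatorname{mean}(Y_{\text{fake}})
\end{aligned}
$$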


```python
        super(TorchWGANGPLoss, self).__init__()

    def forward(self, Y_real, Y_fake, gradInterp):
        # Separate copy of the fake scores, so C_loss and G_loss
        # are built from distinct tensors
        GY_fake = Y_fake.copy()
        self.Y_real = torchify(Y_real)
        self.Y_fake = torchify(Y_fake)
        self.GY_fake = torchify(GY_fake)
        self.gradInterp = torchify(gradInterp)

        # calc grad penalty
        # Row-wise L2 norm of gradInterp, computed two ways and cross-checked
        norm = self.gradInterp.norm(2, dim=1)
        self.norm1 = torch.sqrt(torch.sum(self.gradInterp.pow(2), dim=1))
        assert torch.allclose(norm, self.norm1)

        # Penalize deviation of each row norm from 1, then form the critic
        # (C_loss) and generator (G_loss) losses
        self.gpenalty = self.lambda_ * ((self.norm1 - 1).pow(2)).mean()
        self.C_loss = self.Y_fake.mean() - self.Y_real.mean() + self.gpenalty
        self.G_loss = -self.GY_fake.mean()

    def extract_grads(self, Y_real, Y_fake, gradInterp):
        self.forward(Y_real, Y_fake, gradInterp)
```
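The same computation can be mirrored in NumPy, which is convenient when comparing a NumPy WGAN-GP loss against this PyTorch reference. This is a minimal sketch under stated assumptions: `wgan_gp_losses` and `wgan_gp_grads` are hypothetical helpers, not part of numpy-ml, `grad_interp` is assumed 2-D, and the gradients are derived by hand from the formulas in `forward` (treating the `G_loss` input as a separate copy of `Y_fake`, as `GY_fake` is) rather than taken from `extract_grads`, whose body is not shown here.

```python
import numpy as np


def wgan_gp_losses(Y_real, Y_fake, grad_interp, lambda_=10.0):
    """Hypothetical NumPy mirror of TorchWGANGPLoss.forward (not numpy-ml API).

    Assumes grad_interp is 2-D, shape (n_samples, n_features).
    Returns (C_loss, G_loss, gpenalty) as Python floats.
    """
    Y_real = np.asarray(Y_real, dtype=float)
    Y_fake = np.asarray(Y_fake, dtype=float)
    g = np.asarray(grad_interp, dtype=float)

    # Row-wise L2 norm, computed two ways (mirrors the torch assert)
    norm = np.linalg.norm(g, ord=2, axis=1)
    norm1 = np.sqrt(np.sum(g ** 2, axis=1))
    assert np.allclose(norm, norm1)

    # Gradient penalty: lambda * mean((||g_i||_2 - 1)^2)
    gpenalty = lambda_ * np.mean((norm1 - 1) ** 2)
    C_loss = Y_fake.mean() - Y_real.mean() + gpenalty  # critic loss
    G_loss = -Y_fake.mean()  # generator loss
    return float(C_loss), float(G_loss), float(gpenalty)


def wgan_gp_grads(Y_real, Y_fake, grad_interp, lambda_=10.0):
    """Hand-derived gradients of the losses above (hypothetical helper).

    Assumes every row of grad_interp has a nonzero norm.
    Returns (dC/dY_real, dC/dY_fake, dC/dgrad_interp, dG/dY_fake).
    """
    Y_real = np.asarray(Y_real, dtype=float)
    Y_fake = np.asarray(Y_fake, dtype=float)
    g = np.asarray(grad_interp, dtype=float)
    n = g.shape[0]
    norm = np.sqrt(np.sum(g ** 2, axis=1))

    # Each mean contributes 1/size per entry
    dC_dY_fake = np.full_like(Y_fake, 1.0 / Y_fake.size)
    dC_dY_real = np.full_like(Y_real, -1.0 / Y_real.size)
    # d/dg_i of (lambda/n)(||g_i|| - 1)^2 = (2 lambda/n)(||g_i|| - 1) g_i / ||g_i||
    coef = 2.0 * lambda_ * (norm - 1) / (n * norm)
    dC_dgrad = coef[:, None] * g
    dG_dY_fake = np.full_like(Y_fake, -1.0 / Y_fake.size)
    return dC_dY_real, dC_dY_fake, dC_dgrad, dG_dY_fake


C, G, gp = wgan_gp_losses([[1.0], [3.0]], [[0.0], [2.0]], [[3.0, 4.0], [0.0, 1.0]])
print(C, G, gp)  # 79.0 -1.0 80.0
```

With the toy inputs on the last lines, the row norms are 5 and 1, so the penalty is 10 × mean(16, 0) = 80 and the critic loss is 1 - 2 + 80 = 79.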

Callers (1)

extract_grads · Method · 0.95

Calls (2)

torchify · Function · 0.85
copy · Method · 0.45

Tested by

no test coverage detected