(self, Y_real, Y_fake, gradInterp)
| 85 | super(TorchWGANGPLoss, self).__init__() |
| 86 | |
def forward(self, Y_real, Y_fake, gradInterp):
    """Compute the critic and generator losses and cache them on ``self``.

    Nothing is returned; every intermediate and final quantity is stored
    as an attribute (``Y_real``, ``Y_fake``, ``GY_fake``, ``gradInterp``,
    ``norm1``, ``gpenalty``, ``C_loss``, ``G_loss``).

    Parameters
    ----------
    Y_real
        Scores for real samples; passed through ``torchify``.
    Y_fake
        Scores for fake samples. Must expose ``.copy()`` (presumably a
        NumPy array -- confirm against callers).
    gradInterp
        Gradient values whose norm along ``dim=1`` is penalized, so it
        needs at least two dimensions after ``torchify``.

    Raises
    ------
    AssertionError
        If the explicitly computed norm and ``Tensor.norm`` disagree.
    """
    # The generator loss gets its own copy of the fake scores, taken
    # before any conversion, so it ends up in a separate tensor from the
    # one used by the critic loss.
    generator_scores = Y_fake.copy()

    self.Y_real = torchify(Y_real)
    self.Y_fake = torchify(Y_fake)
    self.GY_fake = torchify(generator_scores)
    self.gradInterp = torchify(gradInterp)

    # Gradient penalty. The 2-norm along dim=1 is obtained twice -- once
    # from the built-in and once spelled out -- and the two are checked
    # against each other. Only the spelled-out version is kept.
    builtin_norm = self.gradInterp.norm(2, dim=1)
    sum_of_squares = torch.sum(self.gradInterp.pow(2), dim=1)
    self.norm1 = torch.sqrt(sum_of_squares)
    assert torch.allclose(builtin_norm, self.norm1)

    # Mean squared deviation of the norm from 1, scaled by lambda_.
    deviation = self.norm1 - 1
    self.gpenalty = self.lambda_ * deviation.pow(2).mean()

    # Critic loss: mean fake score minus mean real score, plus penalty.
    score_gap = self.Y_fake.mean() - self.Y_real.mean()
    self.C_loss = score_gap + self.gpenalty

    # Generator loss: negated mean of the copied fake scores.
    self.G_loss = -self.GY_fake.mean()
| 102 | |
| 103 | def extract_grads(self, Y_real, Y_fake, gradInterp): |
| 104 | self.forward(Y_real, Y_fake, gradInterp) |
no test coverage detected