| 426 | self.wn.res_skip_layers = wn.res_skip_layers |
| 427 | |
def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs):
    """Apply the affine coupling transform to ``x``, or undo it.

    ``x`` is split along dim 1 at ``self.in_channels // 2``. The first half
    is passed through unchanged and also drives ``self.start`` ->
    ``self.wn`` -> ``self.end``, whose output is split the same way into a
    shift and a log-scale that are applied to the second half.

    Args:
        x: Input tensor; the output of ``self.end`` is indexed as 3-D and
            summed over dims 1 and 2.
        x_mask: Optional multiplicative mask. When ``None`` the scalar ``1``
            is used, which leaves every product unchanged.
        reverse: If true, apply ``(x_1 - shift) * exp(-log_scale)``;
            otherwise apply ``shift + exp(log_scale) * x_1``.
        g: Forwarded unchanged as the third argument of ``self.wn``.
        **kwargs: Accepted and ignored.

    Returns:
        A ``(z, logdet)`` tuple. ``z`` is the untouched first half
        concatenated with the transformed second half along dim 1.
        ``logdet`` is the masked log-scale summed over dims 1 and 2, negated
        when ``reverse`` is true.
    """
    mask = 1 if x_mask is None else x_mask
    half = self.in_channels // 2
    passthrough = x[:, :half]
    to_transform = x[:, half:]

    # Only the untouched half feeds the parameter network, so the same
    # shift/log-scale can be recomputed when running in reverse.
    hidden = self.wn(self.start(passthrough) * mask, mask, g)
    params = self.end(hidden)

    shift = params[:, :half, :]
    log_scale = params[:, half:, :]
    if self.sigmoid_scale:
        # Re-express the raw output as the log of a sigmoid of the shifted
        # value; the 1e-6 keeps the log finite when the sigmoid is ~0.
        log_scale = torch.log(1e-6 + torch.sigmoid(log_scale + 2))

    if reverse:
        transformed = (to_transform - shift) * torch.exp(-log_scale) * mask
        logdet = torch.sum(-log_scale * mask, [1, 2])
    else:
        transformed = (shift + torch.exp(log_scale) * to_transform) * mask
        logdet = torch.sum(log_scale * mask, [1, 2])

    return torch.cat([passthrough, transformed], 1), logdet
| 450 | |
def store_inverse(self):
    """Remove weight normalization from the wrapped ``wn`` module.

    Delegates to ``self.wn.remove_weight_norm()`` and returns nothing.
    NOTE(review): presumably called once before inference; confirm against
    the callers.
    """
    wn_module = self.wn
    wn_module.remove_weight_norm()