| 567 | ) |
| 568 | |
def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs):
    """Apply the affine coupling transform (or its inverse) to ``x``.

    The first half of ``x``'s channels (``x_0``) passes through unchanged.
    Together with the conditioning ``g``, it predicts a shift ``m`` and a
    log-scale ``logs``, which are applied to the second half (``x_1``).

    Args:
        x: input tensor, split along dim 1 at ``self.in_channels // 2``.
        x_mask: optional mask multiplied into ``z_1`` and into the
            log-determinant. ``None`` means no masking (scalar 1).
        reverse: if True, apply the inverse transform
            ``z_1 = (x_1 - m) * exp(-logs)`` instead of the forward one
            ``z_1 = m + exp(logs) * x_1``.
        g: conditioning tensor (required despite the ``None`` default).
            After ``utils.unsqueeze``, its first 80 channels are used as mel
            conditioning and the remaining channels as text conditioning.
        **kwargs: accepted for interface compatibility; ignored.

    Returns:
        ``(z, logdet)``. ``z`` is ``x_0`` concatenated with ``z_1`` along
        dim 1. ``logdet`` is ``sum(logs * x_mask)`` over dims 1 and 2, or its
        negation when ``reverse`` is True.

    Raises:
        ValueError: if ``g`` is None.
    """
    if g is None:
        raise ValueError("forward() requires the conditioning tensor `g`")

    # Separate the conditioning into its mel part (first 80 channels after
    # unsqueezing) and its text part, then re-squeeze each.
    # NOTE(review): assumes utils.squeeze / utils.unsqueeze are inverse
    # channel (un)squeeze ops -- confirm against utils.
    g_, _ = utils.unsqueeze(g)
    g_mel, _ = utils.squeeze(g_[:, :80])
    g_txt, _ = utils.squeeze(g_[:, 80:])  # [B, C, T]

    if x_mask is None:
        x_mask = 1  # scalar 1 broadcasts as "no masking"

    half = self.in_channels // 2
    x_0, x_1 = x[:, :half], x[:, half:]

    # Stack x_0 with the two 80-channel halves of g_mel as 3 planes along a
    # new dim 1, then lift to a 4-D feature map.
    h = torch.stack([x_0, g_mel[:, :80], g_mel[:, 80:]], 1)
    h = self.start(h)  # [B, C, N_bins, T]
    B, C, N_bins, T = h.shape

    # Vertical branch: fold time into the batch, one (N_bins, C) slice per
    # (batch, frame); restore to [B, C', N_bins, T] afterwards.
    x_v = self.fft_v(h.permute(0, 3, 2, 1).reshape(B * T, N_bins, C))
    x_v = x_v.reshape(B, T, N_bins, -1).permute(0, 3, 2, 1)

    # Horizontal branch: fold bins into the batch, one (C, T) slice per
    # (batch, bin); restore to [B, C', N_bins, T] afterwards.
    x_h = self.fft_h(h.permute(0, 2, 1, 3).reshape(B * N_bins, C, T))
    x_h = x_h.reshape(B, N_bins, -1, T).permute(0, 2, 1, 3)

    # Text branch: insert a bin axis and tile it to N_bins so it can be
    # concatenated with the other branches. This used to be a hard-coded 10,
    # which made the cat below fail for any other bin count.
    x_g = self.fft_g(g_txt)[:, :, None, :].repeat(1, 1, N_bins, 1)
    out = self.end(torch.cat([x_v, x_h, x_g], 1))

    m = out[:, 0]
    logs = out[:, 1]
    if self.sigmoid_scale:
        # Bound the scale to (1e-6, 1 + 1e-6). The +2 offset makes a raw
        # output of 0 map to a scale of sigmoid(2) ~= 0.88.
        logs = torch.log(1e-6 + torch.sigmoid(logs + 2))

    if reverse:
        z_1 = (x_1 - m) * torch.exp(-logs) * x_mask
        logdet = torch.sum(-logs * x_mask, [1, 2])
    else:
        z_1 = (m + torch.exp(logs) * x_1) * x_mask
        logdet = torch.sum(logs * x_mask, [1, 2])

    z = torch.cat([x_0, z_1], 1)
    return z, logdet
| 608 | |
def store_inverse(self):
    """No-op hook kept for interface compatibility.

    This layer computes its inverse on the fly in ``forward(reverse=True)``,
    so nothing is precomputed or stored here.
    """
    return None