| 529 | |
| 530 | class FreqFFTCouplingBlock(nn.Module): |
def __init__(self, in_channels, hidden_channels, n_layers,
             gin_channels=0, p_dropout=0, sigmoid_scale=False):
    """Build the sub-modules of the coupling block.

    Args:
        in_channels: Stored on the instance; not used to size any layer
            built in this constructor.
        hidden_channels: Channel width shared by every conv / FFT
            sub-module created here. Must be divisible by 4 because of the
            ``GroupNorm(4, hidden_channels)`` layer in ``self.end``.
        n_layers: ``num_layers`` passed to both ``FFTBlocks`` stacks.
        gin_channels: Channel count of the conditioning input. The
            conditioning conv in ``self.fft_g`` is built with
            ``gin_channels - 160`` input channels, so values below 160
            are rejected.
        p_dropout: Stored on the instance; not used by any layer built in
            this constructor.
        sigmoid_scale: Stored on the instance; not used by any layer built
            in this constructor.

    Raises:
        ValueError: If ``gin_channels`` is below 160, which would give the
            conditioning conv a negative input-channel count. Note that
            the default ``gin_channels=0`` falls in this range.
    """
    super().__init__()

    # The conditioning conv consumes `gin_channels - 160` channels.
    # NOTE(review): the meaning of the 160-channel offset is not visible in
    # this block -- presumably forward() strips those channels from `g`
    # before calling fft_g; confirm against the caller.
    g_channel_offset = 160
    if gin_channels < g_channel_offset:
        raise ValueError(
            f"gin_channels must be >= {g_channel_offset} "
            f"(the conditioning conv uses gin_channels - {g_channel_offset} "
            f"input channels), got {gin_channels}"
        )

    self.in_channels = in_channels
    self.hidden_channels = hidden_channels
    self.n_layers = n_layers
    self.gin_channels = gin_channels
    self.p_dropout = p_dropout
    self.sigmoid_scale = sigmoid_scale

    hs = hidden_channels
    stride = 8

    # Strided 2-D conv: 3 input planes -> hs feature maps, downsampling
    # both spatial axes by `stride`.
    self.start = nn.Conv2d(3, hs, kernel_size=stride * 2,
                           stride=stride, padding=stride // 2)

    # Transposed conv that upsamples back by `stride` and emits 2 planes.
    # Its weight and bias start at zero, so the output of `self.end` is
    # exactly zero until training updates them.
    end = nn.ConvTranspose2d(hs, 2, kernel_size=stride, stride=stride)
    nn.init.zeros_(end.weight)
    nn.init.zeros_(end.bias)
    self.end = nn.Sequential(
        nn.Conv2d(hs * 3, hs, 3, 1, 1),  # takes 3 * hs input channels
        nn.ReLU(),
        nn.GroupNorm(4, hs),
        nn.Conv2d(hs, hs, 3, 1, 1),
        end,
    )

    self.fft_v = FFTBlocks(hidden_size=hs, ffn_kernel_size=1, num_layers=n_layers)
    self.fft_h = nn.Sequential(
        nn.Conv1d(hs, hs, 3, 1, 1),
        nn.ReLU(),
        nn.Conv1d(hs, hs, 3, 1, 1),
    )

    # Conditioning branch: strided 1-D conv (same kernel/stride/padding
    # recipe as `self.start`), then an FFTBlocks stack applied between two
    # axis swaps of dims 1 and 2.
    self.fft_g = nn.Sequential(
        nn.Conv1d(gin_channels - g_channel_offset, hs,
                  kernel_size=stride * 2, stride=stride, padding=stride // 2),
        Permute(0, 2, 1),
        FFTBlocks(hidden_size=hs, ffn_kernel_size=1, num_layers=n_layers),
        Permute(0, 2, 1),
    )
| 568 | |
| 569 | def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs): |
| 570 | g_, _ = utils.unsqueeze(g) |