| 405 | upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size, disable_complex=disable_complex) |
| 406 | |
def forward(self, asr, F0_curve, N, s):
    """Run the decoder: encode the inputs, refine them through the decode
    blocks, then hand the result to the generator.

    Args:
        asr: Aligned text/ASR features. Concatenated along dim 1 with the
            projected F0 and N features, and also passed through
            ``self.asr_res`` to form a skip input (presumably shaped
            (batch, channels, frames) -- TODO confirm).
        F0_curve: Pitch curve. A channel axis is inserted at dim 1 before
            ``self.F0_conv``; the raw curve is also handed to the generator.
        N: Second prosody curve (presumably noise/energy -- confirm). A
            channel axis is inserted at dim 1 before ``self.N_conv``.
        s: Style conditioning forwarded to ``self.encode``, to every decode
            block and to the generator.

    Returns:
        Whatever ``self.generator(x, s, F0_curve)`` returns.
    """
    f0_feat = self.F0_conv(F0_curve.unsqueeze(1))
    n_feat = self.N_conv(N.unsqueeze(1))

    hidden = self.encode(torch.cat([asr, f0_feat, n_feat], dim=1), s)
    asr_skip = self.asr_res(asr)

    # Skip features are concatenated in front of each decode block up to and
    # including the first block whose upsample_type is not "none". Later
    # blocks see only the running hidden state (presumably because upsampling
    # changes the frame count so the skips no longer line up -- TODO confirm).
    inject_skips = True
    for block in self.decode:
        if inject_skips:
            hidden = torch.cat([hidden, asr_skip, f0_feat, n_feat], dim=1)
        hidden = block(hidden, s)
        if block.upsample_type != "none":
            inject_skips = False

    return self.generator(hidden, s, F0_curve)