(
self,
target_text: str,
prompt_text: str = "",
prompt_wav_path: str = "",
reference_wav_path: str = "",
min_len: int = 2,
max_len: int = 2000,
inference_timesteps: int = 10,
cfg_value: float = 2.0,
retry_badcase: bool = False,
retry_badcase_max_times: int = 3,
retry_badcase_ratio_threshold: float = 6.0,
trim_silence_vad: bool = False,
streaming: bool = False,
streaming_prefix_len: int = 4,
)
| 461 | |
| 462 | @torch.inference_mode() |
| 463 | def _generate( |
| 464 | self, |
| 465 | target_text: str, |
| 466 | prompt_text: str = "", |
| 467 | prompt_wav_path: str = "", |
| 468 | reference_wav_path: str = "", |
| 469 | min_len: int = 2, |
| 470 | max_len: int = 2000, |
| 471 | inference_timesteps: int = 10, |
| 472 | cfg_value: float = 2.0, |
| 473 | retry_badcase: bool = False, |
| 474 | retry_badcase_max_times: int = 3, |
| 475 | retry_badcase_ratio_threshold: float = 6.0, |
| 476 | trim_silence_vad: bool = False, |
| 477 | streaming: bool = False, |
| 478 | streaming_prefix_len: int = 4, |
| 479 | ) -> Generator[torch.Tensor, None, None]: |
| 480 | if retry_badcase and streaming: |
| 481 | warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.") |
| 482 | retry_badcase = False |
| 483 | |
| 484 | if reference_wav_path and prompt_wav_path: |
| 485 | # Combined mode: reference isolation prefix + continuation suffix |
| 486 | text = prompt_text + target_text |
| 487 | text_token = torch.LongTensor(self.text_tokenizer(text)) |
| 488 | text_token = torch.cat( |
| 489 | [ |
| 490 | text_token, |
| 491 | torch.tensor([self.audio_start_token], dtype=torch.int32, device=text_token.device), |
| 492 | ], |
| 493 | dim=-1, |
| 494 | ) |
| 495 | text_length = text_token.shape[0] |
| 496 | |
| 497 | ref_feat = self._encode_wav( |
| 498 | reference_wav_path, |
| 499 | padding_mode="right", |
| 500 | trim_silence_vad=trim_silence_vad, |
| 501 | ) |
| 502 | prompt_feat = self._encode_wav(prompt_wav_path, padding_mode="left", trim_silence_vad=trim_silence_vad) |
| 503 | prompt_audio_length = prompt_feat.size(0) |
| 504 | |
| 505 | ref_tokens, ref_feats, ref_t_mask, ref_a_mask = self._make_ref_prefix(ref_feat, text_token.device) |
| 506 | |
| 507 | prompt_pad_token = torch.zeros(prompt_audio_length, dtype=torch.int32, device=text_token.device) |
| 508 | text_pad_feat = torch.zeros( |
| 509 | (text_length, self.patch_size, self.audio_vae.latent_dim), |
| 510 | dtype=torch.float32, |
| 511 | device=text_token.device, |
| 512 | ) |
| 513 | |
| 514 | text_token = torch.cat([ref_tokens, text_token, prompt_pad_token]) |
| 515 | audio_feat = torch.cat([ref_feats, text_pad_feat, prompt_feat], dim=0) |
| 516 | text_mask = torch.cat( |
| 517 | [ |
| 518 | ref_t_mask, |
| 519 | torch.ones(text_length, dtype=torch.int32).to(text_token.device), |
| 520 | torch.zeros(prompt_audio_length, dtype=torch.int32).to(text_token.device), |
no test coverage detected