(
self,
target_text: str,
prompt_text: str = "",
prompt_wav_path: str = "",
min_len: int = 2,
max_len: int = 2000,
inference_timesteps: int = 10,
cfg_value: float = 2.0,
retry_badcase: bool = False,
retry_badcase_max_times: int = 3,
retry_badcase_ratio_threshold: float = 6.0, # setting acceptable ratio of audio length to text length (for badcase detection)
streaming: bool = False,
)
| 355 | |
| 356 | @torch.inference_mode() |
| 357 | def _generate( |
| 358 | self, |
| 359 | target_text: str, |
| 360 | prompt_text: str = "", |
| 361 | prompt_wav_path: str = "", |
| 362 | min_len: int = 2, |
| 363 | max_len: int = 2000, |
| 364 | inference_timesteps: int = 10, |
| 365 | cfg_value: float = 2.0, |
| 366 | retry_badcase: bool = False, |
| 367 | retry_badcase_max_times: int = 3, |
| 368 | retry_badcase_ratio_threshold: float = 6.0, # setting acceptable ratio of audio length to text length (for badcase detection) |
| 369 | streaming: bool = False, |
| 370 | ) -> Generator[torch.Tensor, None, None]: |
| 371 | if retry_badcase and streaming: |
| 372 | warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.") |
| 373 | retry_badcase = False |
| 374 | if len(prompt_wav_path) == 0: |
| 375 | text = target_text |
| 376 | text_token = torch.LongTensor(self.text_tokenizer(text)) |
| 377 | text_token = torch.cat( |
| 378 | [ |
| 379 | text_token, |
| 380 | torch.tensor( |
| 381 | [self.audio_start_token], |
| 382 | dtype=torch.int32, |
| 383 | device=text_token.device, |
| 384 | ), |
| 385 | ], |
| 386 | dim=-1, |
| 387 | ) |
| 388 | text_length = text_token.shape[0] |
| 389 | |
| 390 | audio_feat = torch.zeros( |
| 391 | (text_length, self.patch_size, self.audio_vae.latent_dim), |
| 392 | dtype=torch.float32, |
| 393 | device=text_token.device, |
| 394 | ) |
| 395 | text_mask = torch.ones(text_length).type(torch.int32).to(text_token.device) |
| 396 | audio_mask = torch.zeros(text_length).type(torch.int32).to(text_token.device) |
| 397 | |
| 398 | else: |
| 399 | text = prompt_text + target_text |
| 400 | text_token = torch.LongTensor(self.text_tokenizer(text)) |
| 401 | text_token = torch.cat( |
| 402 | [ |
| 403 | text_token, |
| 404 | torch.tensor([self.audio_start_token], dtype=torch.int32, device=text_token.device), |
| 405 | ], |
| 406 | dim=-1, |
| 407 | ) |
| 408 | text_length = text_token.shape[0] |
| 409 | |
| 410 | audio, sr = torchaudio.load(prompt_wav_path) |
| 411 | if audio.size(0) > 1: |
| 412 | audio = audio.mean(dim=0, keepdim=True) |
| 413 | |
| 414 | if sr != self.sample_rate: |
no test coverage detected