MCPcopy
hub / github.com/OpenBMB/VoxCPM / _generate

Method _generate

src/voxcpm/model/voxcpm2.py:463–678  ·  view source on GitHub ↗
(
        self,
        target_text: str,
        prompt_text: str = "",
        prompt_wav_path: str = "",
        reference_wav_path: str = "",
        min_len: int = 2,
        max_len: int = 2000,
        inference_timesteps: int = 10,
        cfg_value: float = 2.0,
        retry_badcase: bool = False,
        retry_badcase_max_times: int = 3,
        retry_badcase_ratio_threshold: float = 6.0,
        trim_silence_vad: bool = False,
        streaming: bool = False,
        streaming_prefix_len: int = 4,
    )

Source from the content-addressed store, hash-verified

461
462 @torch.inference_mode()
463 def _generate(
464 self,
465 target_text: str,
466 prompt_text: str = "",
467 prompt_wav_path: str = "",
468 reference_wav_path: str = "",
469 min_len: int = 2,
470 max_len: int = 2000,
471 inference_timesteps: int = 10,
472 cfg_value: float = 2.0,
473 retry_badcase: bool = False,
474 retry_badcase_max_times: int = 3,
475 retry_badcase_ratio_threshold: float = 6.0,
476 trim_silence_vad: bool = False,
477 streaming: bool = False,
478 streaming_prefix_len: int = 4,
479 ) -> Generator[torch.Tensor, None, None]:
480 if retry_badcase and streaming:
481 warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.")
482 retry_badcase = False
483
484 if reference_wav_path and prompt_wav_path:
485 # Combined mode: reference isolation prefix + continuation suffix
486 text = prompt_text + target_text
487 text_token = torch.LongTensor(self.text_tokenizer(text))
488 text_token = torch.cat(
489 [
490 text_token,
491 torch.tensor([self.audio_start_token], dtype=torch.int32, device=text_token.device),
492 ],
493 dim=-1,
494 )
495 text_length = text_token.shape[0]
496
497 ref_feat = self._encode_wav(
498 reference_wav_path,
499 padding_mode="right",
500 trim_silence_vad=trim_silence_vad,
501 )
502 prompt_feat = self._encode_wav(prompt_wav_path, padding_mode="left", trim_silence_vad=trim_silence_vad)
503 prompt_audio_length = prompt_feat.size(0)
504
505 ref_tokens, ref_feats, ref_t_mask, ref_a_mask = self._make_ref_prefix(ref_feat, text_token.device)
506
507 prompt_pad_token = torch.zeros(prompt_audio_length, dtype=torch.int32, device=text_token.device)
508 text_pad_feat = torch.zeros(
509 (text_length, self.patch_size, self.audio_vae.latent_dim),
510 dtype=torch.float32,
511 device=text_token.device,
512 )
513
514 text_token = torch.cat([ref_tokens, text_token, prompt_pad_token])
515 audio_feat = torch.cat([ref_feats, text_pad_feat, prompt_feat], dim=0)
516 text_mask = torch.cat(
517 [
518 ref_t_mask,
519 torch.ones(text_length, dtype=torch.int32).to(text_token.device),
520 torch.zeros(prompt_audio_length, dtype=torch.int32).to(text_token.device),

Callers (2)

generate — Method · 0.95
generate_streaming — Method · 0.95

Calls (8)

_encode_wav — Method · 0.95
_make_ref_prefix — Method · 0.95
_inference — Method · 0.95
get_dtype — Function · 0.85
next_and_close — Function · 0.85
streaming_decode — Method · 0.80
decode_chunk — Method · 0.80
decode — Method · 0.45

Tested by

No test coverage detected.