MCPcopy
hub / github.com/OpenBMB/VoxCPM / _generate

Method _generate

src/voxcpm/model/voxcpm.py:357–493  ·  view source on GitHub ↗
(
        self,
        target_text: str,
        prompt_text: str = "",
        prompt_wav_path: str = "",
        min_len: int = 2,
        max_len: int = 2000,
        inference_timesteps: int = 10,
        cfg_value: float = 2.0,
        retry_badcase: bool = False,
        retry_badcase_max_times: int = 3,
        retry_badcase_ratio_threshold: float = 6.0,  # setting acceptable ratio of audio length to text length (for badcase detection)
        streaming: bool = False,
    )

Source from the content-addressed store, hash-verified

355
356 @torch.inference_mode()
357 def _generate(
358 self,
359 target_text: str,
360 prompt_text: str = "",
361 prompt_wav_path: str = "",
362 min_len: int = 2,
363 max_len: int = 2000,
364 inference_timesteps: int = 10,
365 cfg_value: float = 2.0,
366 retry_badcase: bool = False,
367 retry_badcase_max_times: int = 3,
368 retry_badcase_ratio_threshold: float = 6.0, # setting acceptable ratio of audio length to text length (for badcase detection)
369 streaming: bool = False,
370 ) -> Generator[torch.Tensor, None, None]:
371 if retry_badcase and streaming:
372 warnings.warn("Retry on bad cases is not supported in streaming mode, setting retry_badcase=False.")
373 retry_badcase = False
374 if len(prompt_wav_path) == 0:
375 text = target_text
376 text_token = torch.LongTensor(self.text_tokenizer(text))
377 text_token = torch.cat(
378 [
379 text_token,
380 torch.tensor(
381 [self.audio_start_token],
382 dtype=torch.int32,
383 device=text_token.device,
384 ),
385 ],
386 dim=-1,
387 )
388 text_length = text_token.shape[0]
389
390 audio_feat = torch.zeros(
391 (text_length, self.patch_size, self.audio_vae.latent_dim),
392 dtype=torch.float32,
393 device=text_token.device,
394 )
395 text_mask = torch.ones(text_length).type(torch.int32).to(text_token.device)
396 audio_mask = torch.zeros(text_length).type(torch.int32).to(text_token.device)
397
398 else:
399 text = prompt_text + target_text
400 text_token = torch.LongTensor(self.text_tokenizer(text))
401 text_token = torch.cat(
402 [
403 text_token,
404 torch.tensor([self.audio_start_token], dtype=torch.int32, device=text_token.device),
405 ],
406 dim=-1,
407 )
408 text_length = text_token.shape[0]
409
410 audio, sr = torchaudio.load(prompt_wav_path)
411 if audio.size(0) > 1:
412 audio = audio.mean(dim=0, keepdim=True)
413
414 if sr != self.sample_rate:

Callers (2)

generate — method · 0.95
generate_streaming — method · 0.95

Calls (5)

_inference — method · 0.95
get_dtype — function · 0.85
next_and_close — function · 0.85
encode — method · 0.45
decode — method · 0.45

Tested by

no test coverage detected