MCPcopy
hub / github.com/microsoft/VibeVoice / load

Method load

demo/web/app.py:70–126  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

68 self._torch_device = torch.device(device)
69
def load(self) -> None:
    """Load the processor, the model, and the default voice preset.

    Reads ``self.model_path``, ``self.device`` and ``self.inference_steps``.
    Sets ``self.processor``, ``self.model``, ``self.voice_presets`` and
    ``self.default_voice_key``, and pre-caches the default voice.

    On CUDA the model is first loaded with ``flash_attention_2``; if that
    load raises, the error is reported and the load is retried once with
    ``sdpa``.

    Raises:
        Exception: Whatever ``from_pretrained`` raised when the primary
            attention implementation was not ``flash_attention_2`` (no
            fallback exists for that case), or whatever the SDPA retry
            itself raised.
    """
    print(f"[startup] Loading processor from {self.model_path}")
    self.processor = VibeVoiceStreamingProcessor.from_pretrained(self.model_path)

    # Decide dtype & attention
    if self.device == "mps":
        load_dtype = torch.float32
        # No device_map on MPS: the model is moved with .to("mps") after loading.
        device_map = None
        attn_impl_primary = "sdpa"
    elif self.device == "cuda":
        load_dtype = torch.bfloat16
        device_map = "cuda"
        attn_impl_primary = "flash_attention_2"
    else:
        load_dtype = torch.float32
        device_map = "cpu"
        attn_impl_primary = "sdpa"

    # Report the device actually targeted; device_map alone is None on MPS.
    effective_device = device_map if device_map is not None else self.device
    print(f"Using device: {effective_device}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}")

    # Load model
    try:
        self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
            self.model_path,
            torch_dtype=load_dtype,
            device_map=device_map,
            attn_implementation=attn_impl_primary,
        )
    except Exception as e:
        # Only a failed flash_attention_2 load has a fallback; anything else
        # propagates unchanged.
        if attn_impl_primary != "flash_attention_2":
            raise

        # Surface the underlying failure so the retry does not hide its cause.
        print(f"Primary model load failed: {type(e).__name__}: {e}")
        print("Error loading the model. Trying to use SDPA. However, note that only flash_attention_2 has been fully tested, and using SDPA may result in lower audio quality.")

        self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
            self.model_path,
            torch_dtype=load_dtype,
            device_map=device_map,
            attn_implementation="sdpa",
        )
        print("Load model with SDPA successfully ")

    if self.device == "mps":
        self.model.to("mps")

    self.model.eval()

    # Rebuild the noise scheduler from its existing config, overriding the
    # solver algorithm and the beta schedule.
    self.model.model.noise_scheduler = self.model.model.noise_scheduler.from_config(
        self.model.model.noise_scheduler.config,
        algorithm_type="sde-dpmsolver++",
        beta_schedule="squaredcos_cap_v2",
    )
    self.model.set_ddpm_inference_steps(num_steps=self.inference_steps)

    # VOICE_PRESET is optional; preset_name is None when it is unset.
    self.voice_presets = self._load_voice_presets()
    preset_name = os.environ.get("VOICE_PRESET")
    self.default_voice_key = self._determine_voice_key(preset_name)
    self._ensure_voice_cached(self.default_voice_key)
127

Callers 14

_startupFunction · 0.95
from_pretrainedMethod · 0.80
from_pretrainedMethod · 0.80
from_pretrainedMethod · 0.80
_load_audio_from_pathMethod · 0.80
_load_samplesMethod · 0.80
mainFunction · 0.80
_ensure_voice_cachedMethod · 0.80
get_feature_extractorMethod · 0.80
patch_tokenizer_configFunction · 0.80
patch_tokenizer_jsonFunction · 0.80

Calls 6

_load_voice_presetsMethod · 0.95
_determine_voice_keyMethod · 0.95
_ensure_voice_cachedMethod · 0.95
getMethod · 0.80
from_pretrainedMethod · 0.45

Tested by

no test coverage detected