(self)
| 68 | self._torch_device = torch.device(device) |
| 69 | |
def load(self) -> None:
    """Load the processor and model onto ``self.device`` and prime the default voice.

    Steps:
      1. Load ``VibeVoiceStreamingProcessor`` from ``self.model_path``.
      2. Pick dtype / ``device_map`` / attention backend for the device:
         - ``"mps"``: float32, no ``device_map`` (the model is moved with
           ``.to("mps")`` after loading), SDPA attention.
         - ``"cuda"`` or ``"cuda:<index>"``: bfloat16 on that device with
           ``flash_attention_2``; if that load raises, the error is printed
           and the load is retried once with SDPA.
         - anything else: float32 on CPU with SDPA.
      3. Put the model in eval mode, rebuild its noise scheduler from the
         existing config with ``algorithm_type="sde-dpmsolver++"`` and
         ``beta_schedule="squaredcos_cap_v2"``, and set the number of
         diffusion steps to ``self.inference_steps``.
      4. Load voice presets, resolve the default voice key from the
         ``VOICE_PRESET`` environment variable (may be unset), and cache
         that voice.

    Raises:
        Exception: anything raised while loading the processor; any model-load
            error when the primary attention backend is not flash_attention_2
            (re-raised unchanged); any error from the SDPA retry.
    """
    print(f"[startup] Loading processor from {self.model_path}")
    self.processor = VibeVoiceStreamingProcessor.from_pretrained(self.model_path)

    # Decide dtype & attention. Compare on the string form, and accept
    # indexed CUDA devices ("cuda:1"), which would otherwise silently fall
    # through to the CPU branch even though the service targets that GPU.
    device = str(self.device)
    if device == "mps":
        load_dtype = torch.float32
        device_map = None  # loaded without a device_map, moved to MPS below
        attn_impl_primary = "sdpa"
    elif device == "cuda" or device.startswith("cuda:"):
        load_dtype = torch.bfloat16
        device_map = device
        attn_impl_primary = "flash_attention_2"
    else:
        load_dtype = torch.float32
        device_map = 'cpu'
        attn_impl_primary = "sdpa"
    print(f"Using device: {device_map}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}")

    # Load model
    try:
        self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
            self.model_path,
            torch_dtype=load_dtype,
            device_map=device_map,
            attn_implementation=attn_impl_primary,
        )
    except Exception as e:
        # Only the flash_attention_2 path has a fallback; everything else
        # propagates the original error untouched.
        if attn_impl_primary != 'flash_attention_2':
            raise
        # Keep the reason flash_attention_2 failed visible instead of
        # discarding it.
        print(f"[startup] flash_attention_2 load failed: {e!r}")
        print("Error loading the model. Trying to use SDPA. However, note that only flash_attention_2 has been fully tested, and using SDPA may result in lower audio quality.")

        # Retry on the same device/dtype with SDPA attention.
        self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
            self.model_path,
            torch_dtype=load_dtype,
            device_map=device_map,
            attn_implementation='sdpa',
        )
        print("Load model with SDPA successfully ")

    if device == "mps":
        # MPS is loaded without a device_map, so move it explicitly.
        self.model.to("mps")

    self.model.eval()

    # Rebuild the noise scheduler from its own config, overriding the
    # solver algorithm and beta schedule.
    self.model.model.noise_scheduler = self.model.model.noise_scheduler.from_config(
        self.model.model.noise_scheduler.config,
        algorithm_type="sde-dpmsolver++",
        beta_schedule="squaredcos_cap_v2",
    )
    self.model.set_ddpm_inference_steps(num_steps=self.inference_steps)

    self.voice_presets = self._load_voice_presets()
    preset_name = os.environ.get("VOICE_PRESET")
    self.default_voice_key = self._determine_voice_key(preset_name)
    self._ensure_voice_cached(self.default_voice_key)
| 127 |
no test coverage detected