Load a Whisper ASR model for reference audio transcription. Args: model_name: HuggingFace model name or local path for the Whisper model.
(self, model_name: str = "openai/whisper-large-v3-turbo")
| 296 | # ------------------------------------------------------------------- |
| 297 | |
def load_asr_model(self, model_name: str = "openai/whisper-large-v3-turbo"):
    """Build the Whisper speech-recognition pipeline used for reference audio.

    The resulting transformers pipeline is stored on ``self._asr_pipe``.
    Weights are requested in ``torch.float16`` when ``self.device`` names a
    CUDA device, and in ``torch.float32`` otherwise.

    Args:
        model_name: Hugging Face model id or local path of the Whisper
            checkpoint. It is passed through ``_resolve_model_path`` before
            the pipeline is constructed.
    """
    # Function-scope import: transformers is only loaded when this runs.
    from transformers import pipeline as hf_pipeline

    # Logs the name as given by the caller, before path resolution.
    logger.info("Loading ASR model %s ...", model_name)

    on_cuda = str(self.device).startswith("cuda")
    if on_cuda:
        precision = torch.float16
    else:
        precision = torch.float32

    checkpoint = _resolve_model_path(model_name)

    pipeline_kwargs = dict(
        model=checkpoint,
        dtype=precision,
        device_map=self.device,
    )
    self._asr_pipe = hf_pipeline("automatic-speech-recognition", **pipeline_kwargs)
    logger.info("ASR model loaded on %s.", self.device)
| 320 | |
| 321 | @torch.inference_mode() |
| 322 | def transcribe( |
no test coverage detected