(cls, pretrained_model_name_or_path, *args, **kwargs)
| 244 | |
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
    """Load the model and, unless in train mode, attach inference helpers.

    The following keyword arguments are consumed here and removed before the
    remaining ``*args`` / ``**kwargs`` are forwarded to the parent class's
    ``from_pretrained``.

    Args:
        pretrained_model_name_or_path: Model id or local path. It is turned
            into a local path by ``_resolve_model_path`` before loading.
        train (bool, default False): If True, return the model as loaded by
            the parent class. No tokenizers, feature extractor, duration
            estimator or ASR model are attached.
        load_asr (bool, default False): If True (and not in train mode), also
            call ``model.load_asr_model(model_name=asr_model_name)``.
        asr_model_name (str, default "openai/whisper-large-v3-turbo"): Name
            passed to ``load_asr_model``.

    Returns:
        The loaded model. Outside train mode it additionally has these
        attributes set: ``text_tokenizer``, ``audio_tokenizer``,
        ``feature_extractor``, ``sampling_rate`` (taken from the feature
        extractor) and ``duration_estimator``.
    """
    train_mode = kwargs.pop("train", False)
    load_asr = kwargs.pop("load_asr", False)
    asr_model_name = kwargs.pop("asr_model_name", "openai/whisper-large-v3-turbo")

    # Suppress noisy INFO logs from transformers/huggingface_hub during loading.
    # logging.disable() is process-wide, so never *lower* a stricter level the
    # caller may already have set (e.g. logging.disable(logging.WARNING)):
    # doing so would briefly re-enable records they deliberately silenced.
    _prev_disable = logging.root.manager.disable
    logging.disable(max(_prev_disable, logging.INFO))

    try:
        # Resolve to local path first; download only if not already cached
        resolved_path = _resolve_model_path(pretrained_model_name_or_path)

        model = super().from_pretrained(resolved_path, *args, **kwargs)

        if not train_mode:
            model.text_tokenizer = AutoTokenizer.from_pretrained(resolved_path)

            # Prefer an audio tokenizer shipped inside the checkpoint directory;
            # otherwise fall back to the standalone higgs-audio-v2 tokenizer.
            audio_tokenizer_path = os.path.join(resolved_path, "audio_tokenizer")
            if not os.path.isdir(audio_tokenizer_path):
                audio_tokenizer_path = _resolve_model_path(
                    "eustlb/higgs-audio-v2-tokenizer"
                )

            # higgs-audio-v2-tokenizer does not support MPS
            # (output channels > 65536), so place it on CPU in that case.
            tokenizer_device = (
                "cpu" if str(model.device).startswith("mps") else model.device
            )
            model.audio_tokenizer = HiggsAudioV2TokenizerModel.from_pretrained(
                audio_tokenizer_path, device_map=tokenizer_device
            )
            model.feature_extractor = AutoFeatureExtractor.from_pretrained(
                audio_tokenizer_path
            )

            model.sampling_rate = model.feature_extractor.sampling_rate

            model.duration_estimator = RuleDurationEstimator()

            if load_asr:
                model.load_asr_model(model_name=asr_model_name)
    finally:
        # Always restore the caller's logging state, even if loading failed.
        logging.disable(_prev_disable)

    return model
| 293 | |
| 294 | # ------------------------------------------------------------------- |
| 295 | # ASR support (optional, for auto-transcription) |
no test coverage detected