MCPcopy
hub / github.com/k2-fsa/OmniVoice / from_pretrained

Method from_pretrained

omnivoice/models/omnivoice.py:246–292  ·  view source on GitHub ↗
(cls, pretrained_model_name_or_path, *args, **kwargs)

Source from the content-addressed store, hash-verified

244
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
    """Load an OmniVoice model and attach its inference-time helpers.

    The path is first resolved with ``_resolve_model_path``, then the
    weights are loaded through the parent class's ``from_pretrained``.
    Unless ``train=True`` is passed, the following are attached to the
    returned model:

    - the text tokenizer;
    - the Higgs audio tokenizer and its feature extractor (plus
      ``sampling_rate``);
    - a rule-based duration estimator;
    - optionally, an ASR model.

    Args:
        pretrained_model_name_or_path: Model identifier or local directory,
            resolved through ``_resolve_model_path``.
        *args: Forwarded unchanged to the parent ``from_pretrained``.
        **kwargs: Forwarded to the parent ``from_pretrained`` after these
            OmniVoice-specific keys are removed:

            - ``train`` (default ``False``): when truthy, only the model
              weights are loaded; no tokenizers, feature extractor,
              duration estimator or ASR model are attached.
            - ``load_asr`` (default ``False``): when truthy, also call
              ``model.load_asr_model``.
            - ``asr_model_name`` (default
              ``"openai/whisper-large-v3-turbo"``): passed to
              ``load_asr_model`` as ``model_name``.

    Returns:
        The loaded model instance.
    """
    train_mode = kwargs.pop("train", False)
    load_asr = kwargs.pop("load_asr", False)
    asr_model_name = kwargs.pop("asr_model_name", "openai/whisper-large-v3-turbo")

    # Suppress noisy INFO logs from transformers/huggingface_hub during
    # loading. logging.disable() is process-wide: it applies to every
    # logger, not just those libraries. Take the max so a stricter
    # threshold the caller already set (e.g. logging.disable(CRITICAL))
    # is never temporarily lowered while we load.
    _prev_disable = logging.root.manager.disable
    logging.disable(max(_prev_disable, logging.INFO))

    try:
        # Resolve to local path first; download only if not already cached
        resolved_path = _resolve_model_path(pretrained_model_name_or_path)

        model = super().from_pretrained(resolved_path, *args, **kwargs)

        if not train_mode:
            model.text_tokenizer = AutoTokenizer.from_pretrained(resolved_path)

            # Use the audio tokenizer shipped inside the checkpoint
            # directory when present; otherwise resolve the
            # "eustlb/higgs-audio-v2-tokenizer" checkpoint instead.
            audio_tokenizer_path = os.path.join(resolved_path, "audio_tokenizer")
            if not os.path.isdir(audio_tokenizer_path):
                audio_tokenizer_path = _resolve_model_path(
                    "eustlb/higgs-audio-v2-tokenizer"
                )

            # higgs-audio-v2-tokenizer does not support MPS
            # (output channels > 65536), so it is kept on CPU in that case.
            tokenizer_device = (
                "cpu" if str(model.device).startswith("mps") else model.device
            )
            model.audio_tokenizer = HiggsAudioV2TokenizerModel.from_pretrained(
                audio_tokenizer_path, device_map=tokenizer_device
            )
            model.feature_extractor = AutoFeatureExtractor.from_pretrained(
                audio_tokenizer_path
            )
            model.sampling_rate = model.feature_extractor.sampling_rate

            model.duration_estimator = RuleDurationEstimator()

            if load_asr:
                model.load_asr_model(model_name=asr_model_name)
    finally:
        # Restore the caller's original threshold, even if loading failed.
        logging.disable(_prev_disable)

    return model
293
294 # -------------------------------------------------------------------
295 # ASR support (optional, for auto-transcription)

Callers 6

process_init (function) · 0.80
main (function) · 0.80
main (function) · 0.80
process_init (function) · 0.80
process_init (function) · 0.80

Calls 3

_resolve_model_path (function) · 0.85
load_asr_model (method) · 0.80

Tested by

no test coverage detected