| 180 | |
| 181 | |
| 182 | class AutoModel: |
| 183 | |
| 184 | def __init__(self, **kwargs): |
| 185 | """Initialize AutoModel with ASR model and optional sub-models. |
| 186 | |
| 187 | Args: |
| 188 | model (str): Model name (hub alias or full ID) or local path. |
| 189 | device (str): Device for inference. "cuda:0", "cpu", "mps", "npu:0". |
| 190 | Falls back to CPU if specified device is unavailable. |
| 191 | vad_model (str, optional): VAD model for long audio segmentation. |
| 192 | Enables processing of any-length audio. |
| 193 | vad_kwargs (dict, optional): VAD config, e.g. {"max_single_segment_time": 60000}. |
| 194 | punc_model (str, optional): Punctuation restoration model. |
| 195 | Not needed for Fun-ASR-Nano/SenseVoice/Qwen3-ASR (they output punctuation natively). |
| 196 | spk_model (str, optional): Speaker model for diarization ("cam++" or full model ID). |
| 197 | Requires vad_model. For Qwen3-ASR, also requires forced_aligner. |
| 198 | spk_mode (str, optional): Speaker diarization mode. "punc_segment" (default) or "vad_segment". |
| 199 | hub (str): Model hub. "ms" (ModelScope, default) or "hf" (HuggingFace). |
| 200 | ncpu (int): CPU threads (default: 4). |
| 201 | disable_update (bool): Skip version check on startup. |
| 202 | disable_pbar (bool): Disable tqdm progress bars. |
| 203 | **kwargs: Additional model-specific parameters (passed to config.yaml overrides). |
| 204 | |
| 205 | Examples: |
| 206 | >>> model = AutoModel(model="paraformer-zh", vad_model="fsmn-vad", punc_model="ct-punc") |
| 207 | >>> model = AutoModel(model="FunAudioLLM/Fun-ASR-Nano-2512", trust_remote_code=True, |
| 208 | ... remote_code="./model.py", vad_model="fsmn-vad", spk_model="cam++", hub="hf") |
| 209 | """ |
| 210 | try: |
| 211 | from funasr.utils.version_checker import check_for_update |
| 212 | |
| 213 | check_for_update(disable=kwargs.get("disable_update", False)) |
| 214 | except: |
| 215 | pass |
| 216 | |
| 217 | log_level = getattr(logging, kwargs.get("log_level", "INFO").upper()) |
| 218 | logging.basicConfig(level=log_level) |
| 219 | |
| 220 | model, kwargs = self.build_model(**kwargs) |
| 221 | |
| 222 | # if vad_model is not None, build vad model else None |
| 223 | vad_model = kwargs.get("vad_model", None) |
| 224 | vad_kwargs = {} if kwargs.get("vad_kwargs", {}) is None else kwargs.get("vad_kwargs", {}) |
| 225 | if vad_model is not None: |
| 226 | logging.info("Building VAD model.") |
| 227 | vad_kwargs["model"] = vad_model |
| 228 | vad_kwargs["model_revision"] = kwargs.get("vad_model_revision", "master") |
| 229 | vad_kwargs["device"] = kwargs["device"] |
| 230 | vad_kwargs.setdefault("ncpu", kwargs.get("ncpu", 4)) |
| 231 | if "hub" in kwargs: |
| 232 | vad_kwargs.setdefault("hub", kwargs["hub"]) |
| 233 | vad_model, vad_kwargs = self.build_model(**vad_kwargs) |
| 234 | |
| 235 | # if punc_model is not None, build punc model else None |
| 236 | punc_model = kwargs.get("punc_model", None) |
| 237 | punc_kwargs = {} if kwargs.get("punc_kwargs", {}) is None else kwargs.get("punc_kwargs", {}) |
| 238 | if punc_model is not None: |
| 239 | logging.info("Building punc model.") |
no outgoing calls
searching dependent graphs…