MCPcopy
hub / github.com/OpenBMB/VoxCPM / run_inference

Function run_inference

lora_ft_webui.py:255–371  ·  view source on GitHub ↗
(text, prompt_wav, prompt_text, lora_selection, cfg_scale, steps, seed, pretrained_path=None)

Source from the content-addressed store, hash-verified

253
254
255def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, steps, seed, pretrained_path=None):
256 # 如果选择了 LoRA 模型且当前模型未加载,尝试从 LoRA config 读取 base_model
257 if current_model is None:
258 # 优先使用用户指定的预训练模型路径
259 base_model_path = pretrained_path if pretrained_path and pretrained_path.strip() else default_pretrained_path
260
261 # 如果选择了 LoRA,尝试从其 config 读取 base_model
262 if lora_selection and lora_selection != "None":
263 full_lora_path = os.path.join("lora", lora_selection)
264 lora_config_file = os.path.join(full_lora_path, "lora_config.json")
265
266 if os.path.exists(lora_config_file):
267 try:
268 with open(lora_config_file, "r", encoding="utf-8") as f:
269 lora_info = json.load(f)
270 saved_base_model = lora_info.get("base_model")
271
272 if saved_base_model:
273 # 优先使用保存的 base_model 路径
274 if os.path.exists(saved_base_model):
275 base_model_path = saved_base_model
276 print(f"Using base model from LoRA config: {base_model_path}", file=sys.stderr)
277 else:
278 print(f"Warning: Saved base_model path not found: {saved_base_model}", file=sys.stderr)
279 print(f"Falling back to default: {base_model_path}", file=sys.stderr)
280 except Exception as e:
281 print(f"Warning: Failed to read base_model from LoRA config: {e}", file=sys.stderr)
282
283 # 加载模型
284 lora_to_load = lora_selection if lora_selection and lora_selection != "None" else None
285 try:
286 print(f"Loading base model: {base_model_path}", file=sys.stderr)
287 load_model(base_model_path, lora_to_load)
288 if lora_to_load:
289 print(f"Model loaded with LoRA: {lora_selection}", file=sys.stderr)
290 except Exception as e:
291 error_msg = f"Failed to load model from {base_model_path}: {str(e)}"
292 print(error_msg, file=sys.stderr)
293 return None, error_msg
294 lora_just_loaded = lora_to_load
295 else:
296 lora_just_loaded = None
297
298 # Handle LoRA hot-swapping
299 assert current_model is not None, "Model must be loaded before inference"
300 if lora_selection and lora_selection != "None":
301 full_lora_path = os.path.join("lora", lora_selection)
302
303 if lora_just_loaded != lora_selection:
304 new_lora_config, new_base_model = load_lora_config_from_checkpoint(full_lora_path)
305 current_r = current_model.tts_model.lora_config.r if current_model.tts_model.lora_config else None
306 new_r = new_lora_config.r if new_lora_config else None
307
308 if new_r is not None and current_r is not None and new_r != current_r:
309 print(f"LoRA rank mismatch (model r={current_r}, checkpoint r={new_r}), reloading...", file=sys.stderr)
310 reload_base = (
311 new_base_model if new_base_model and os.path.exists(new_base_model)
312 else (pretrained_path if pretrained_path and pretrained_path.strip() else default_pretrained_path)

Callers

nothing calls this directly

Calls 6

recognize_audioFunction · 0.85
load_loraMethod · 0.80
load_modelFunction · 0.70
set_lora_enabledMethod · 0.45
generateMethod · 0.45

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…