(text, prompt_wav, prompt_text, lora_selection, cfg_scale, steps, seed, pretrained_path=None)
| 253 | |
| 254 | |
| 255 | def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, steps, seed, pretrained_path=None): |
| 256 | # 如果选择了 LoRA 模型且当前模型未加载,尝试从 LoRA config 读取 base_model |
| 257 | if current_model is None: |
| 258 | # 优先使用用户指定的预训练模型路径 |
| 259 | base_model_path = pretrained_path if pretrained_path and pretrained_path.strip() else default_pretrained_path |
| 260 | |
| 261 | # 如果选择了 LoRA,尝试从其 config 读取 base_model |
| 262 | if lora_selection and lora_selection != "None": |
| 263 | full_lora_path = os.path.join("lora", lora_selection) |
| 264 | lora_config_file = os.path.join(full_lora_path, "lora_config.json") |
| 265 | |
| 266 | if os.path.exists(lora_config_file): |
| 267 | try: |
| 268 | with open(lora_config_file, "r", encoding="utf-8") as f: |
| 269 | lora_info = json.load(f) |
| 270 | saved_base_model = lora_info.get("base_model") |
| 271 | |
| 272 | if saved_base_model: |
| 273 | # 优先使用保存的 base_model 路径 |
| 274 | if os.path.exists(saved_base_model): |
| 275 | base_model_path = saved_base_model |
| 276 | print(f"Using base model from LoRA config: {base_model_path}", file=sys.stderr) |
| 277 | else: |
| 278 | print(f"Warning: Saved base_model path not found: {saved_base_model}", file=sys.stderr) |
| 279 | print(f"Falling back to default: {base_model_path}", file=sys.stderr) |
| 280 | except Exception as e: |
| 281 | print(f"Warning: Failed to read base_model from LoRA config: {e}", file=sys.stderr) |
| 282 | |
| 283 | # 加载模型 |
| 284 | lora_to_load = lora_selection if lora_selection and lora_selection != "None" else None |
| 285 | try: |
| 286 | print(f"Loading base model: {base_model_path}", file=sys.stderr) |
| 287 | load_model(base_model_path, lora_to_load) |
| 288 | if lora_to_load: |
| 289 | print(f"Model loaded with LoRA: {lora_selection}", file=sys.stderr) |
| 290 | except Exception as e: |
| 291 | error_msg = f"Failed to load model from {base_model_path}: {str(e)}" |
| 292 | print(error_msg, file=sys.stderr) |
| 293 | return None, error_msg |
| 294 | lora_just_loaded = lora_to_load |
| 295 | else: |
| 296 | lora_just_loaded = None |
| 297 | |
| 298 | # Handle LoRA hot-swapping |
| 299 | assert current_model is not None, "Model must be loaded before inference" |
| 300 | if lora_selection and lora_selection != "None": |
| 301 | full_lora_path = os.path.join("lora", lora_selection) |
| 302 | |
| 303 | if lora_just_loaded != lora_selection: |
| 304 | new_lora_config, new_base_model = load_lora_config_from_checkpoint(full_lora_path) |
| 305 | current_r = current_model.tts_model.lora_config.r if current_model.tts_model.lora_config else None |
| 306 | new_r = new_lora_config.r if new_lora_config else None |
| 307 | |
| 308 | if new_r is not None and current_r is not None and new_r != current_r: |
| 309 | print(f"LoRA rank mismatch (model r={current_r}, checkpoint r={new_r}), reloading...", file=sys.stderr) |
| 310 | reload_base = ( |
| 311 | new_base_model if new_base_model and os.path.exists(new_base_model) |
| 312 | else (pretrained_path if pretrained_path and pretrained_path.strip() else default_pretrained_path) |
nothing calls this directly
no test coverage detected
searching dependent graphs…