获取模型权重路径 Args: model_name: 模型名称 create_if_missing: 如果目录不存在是否创建 Returns: 模型权重路径,找不到返回None
(self, model_name: str, create_if_missing: bool = True)
| 61 | return str(project_root / 'weights') |
| 62 | |
| 63 | def get_model_path(self, model_name: str, create_if_missing: bool = True) -> Optional[str]: |
| 64 | """ |
| 65 | 获取模型权重路径 |
| 66 | |
| 67 | Args: |
| 68 | model_name: 模型名称 |
| 69 | create_if_missing: 如果目录不存在是否创建 |
| 70 | |
| 71 | Returns: |
| 72 | 模型权重路径,找不到返回None |
| 73 | """ |
| 74 | if model_name not in self.model_paths: |
| 75 | logger.warning(f"未知的模型名称: {model_name}") |
| 76 | # 尝试直接拼接路径 |
| 77 | model_path = self.weights_dir / model_name |
| 78 | else: |
| 79 | model_path = self.weights_dir / self.model_paths[model_name] |
| 80 | |
| 81 | if not model_path.exists(): |
| 82 | if create_if_missing: |
| 83 | try: |
| 84 | model_path.mkdir(parents=True, exist_ok=True) |
| 85 | logger.info(f"创建模型目录: {model_path}") |
| 86 | except Exception as e: |
| 87 | logger.error(f"创建模型目录失败: {e}") |
| 88 | return None |
| 89 | else: |
| 90 | logger.warning(f"模型路径不存在: {model_path}") |
| 91 | return None |
| 92 | |
| 93 | return str(model_path) |
| 94 | |
| 95 | def list_available_models(self) -> Dict[str, str]: |
| 96 | """ |
no test coverage detected