加载ONNX模型 Args: model_path: 模型文件路径 Returns: ONNX推理会话对象 Raises: ModelLoadError: 当模型加载失败时
(self, model_path: str)
| 53 | print(f"GPU设置失败,回退到CPU模式: {str(e)}") |
| 54 | |
| 55 | def load_model(self, model_path: str) -> onnxruntime.InferenceSession: |
| 56 | """ |
| 57 | 加载ONNX模型 |
| 58 | |
| 59 | Args: |
| 60 | model_path: 模型文件路径 |
| 61 | |
| 62 | Returns: |
| 63 | ONNX推理会话对象 |
| 64 | |
| 65 | Raises: |
| 66 | ModelLoadError: 当模型加载失败时 |
| 67 | """ |
| 68 | try: |
| 69 | if not os.path.exists(model_path): |
| 70 | raise ModelLoadError(f"模型文件不存在: {model_path}") |
| 71 | |
| 72 | # 设置ONNX运行时日志级别 |
| 73 | onnxruntime.set_default_logger_severity(3) |
| 74 | |
| 75 | # 创建推理会话 |
| 76 | session = onnxruntime.InferenceSession(model_path, providers=self.providers) |
| 77 | |
| 78 | return session |
| 79 | |
| 80 | except Exception as e: |
| 81 | raise ModelLoadError(f"模型加载失败: {str(e)}") from e |
| 82 | |
| 83 | def get_model_info(self, session: onnxruntime.InferenceSession) -> Dict[str, Any]: |
| 84 | """ |
no test coverage detected