(
model_type: str,
model_id: str,
api_base: str | None = None,
api_key: str | None = None,
provider: str | None = None,
)
| 186 | |
| 187 | |
def load_model(
    model_type: str,
    model_id: str,
    api_base: str | None = None,
    api_key: str | None = None,
    provider: str | None = None,
) -> Model:
    """Instantiate the model class selected by ``model_type``.

    Args:
        model_type: Name of the class to build. Recognised values are
            ``"OpenAIModel"``, ``"LiteLLMModel"``, ``"TransformersModel"``
            and ``"InferenceClientModel"``.
        model_id: Model identifier, forwarded unchanged to every class.
        api_base: Endpoint URL. Used by ``"OpenAIModel"`` (a falsy value
            falls back to the Fireworks inference URL) and passed through
            as-is to ``"LiteLLMModel"``; the other types ignore it.
        api_key: Credential. ``"OpenAIModel"`` falls back to the
            ``FIREWORKS_API_KEY`` environment variable and
            ``"InferenceClientModel"`` to ``HF_API_KEY`` when it is falsy;
            ``"LiteLLMModel"`` receives it as-is; ``"TransformersModel"``
            ignores it.
        provider: Forwarded only to ``"InferenceClientModel"``.

    Returns:
        The newly constructed model instance.

    Raises:
        ValueError: If ``model_type`` is not one of the recognised names.
    """
    if model_type == "OpenAIModel":
        # Falsy arguments (None or "") default to the Fireworks credentials
        # and endpoint rather than being passed through.
        resolved_key = api_key or os.getenv("FIREWORKS_API_KEY")
        resolved_base = api_base or "https://api.fireworks.ai/inference/v1"
        return OpenAIModel(
            api_key=resolved_key,
            api_base=resolved_base,
            model_id=model_id,
        )

    if model_type == "LiteLLMModel":
        # No fallbacks here: key and base are handed over exactly as given.
        return LiteLLMModel(model_id=model_id, api_key=api_key, api_base=api_base)

    if model_type == "TransformersModel":
        return TransformersModel(model_id=model_id, device_map="auto")

    if model_type == "InferenceClientModel":
        hf_token = api_key or os.getenv("HF_API_KEY")
        return InferenceClientModel(
            model_id=model_id,
            token=hf_token,
            provider=provider,
        )

    raise ValueError(f"Unsupported model type: {model_type}")
| 217 | |
| 218 | |
| 219 | def run_smolagent( |
searching dependent graphs…