Create chat completion based on the model, messages, and any extra arguments.
(self, model: str, messages: list, **kwargs)
| 82 | self.client = client |
| 83 | |
def create(self, model: str, messages: list, **kwargs):
    """
    Create chat completion based on the model, messages, and any extra arguments.

    Args:
        model: Identifier of the form ``"provider:model"``, e.g.
            ``"google:gemini-xx"``. Only the first ``":"`` separates the
            provider key from the model name, so the model name itself may
            contain further colons.
        messages: Chat messages, forwarded unchanged to the provider.
        **kwargs: Extra arguments, forwarded unchanged to the provider.

    Returns:
        Whatever the selected provider's ``chat_completions_create`` returns.

    Raises:
        ValueError: If ``model`` has no ``":"`` separator, if the model name
            after the separator is empty, if the provider key is not among
            the supported providers, or if the provider could not be loaded.
    """
    invalid_format = f"Invalid model format. Expected 'provider:model', got '{model}'"

    # Check that correct format is used
    if ":" not in model:
        raise ValueError(invalid_format)

    # Extract the provider key from the model identifier, e.g., "google:gemini-xx".
    # maxsplit=1 keeps any later colons inside the model name.
    provider_key, model_name = model.split(":", 1)

    # "provider:" (nothing, or only whitespace, after the colon) is not a usable
    # identifier; reject it here instead of forwarding an empty model name.
    if not model_name.strip():
        raise ValueError(invalid_format)

    # Validate if the provider is supported
    supported_providers = ProviderFactory.get_supported_providers()
    if provider_key not in supported_providers:
        raise ValueError(
            f"Invalid provider key '{provider_key}'. Supported providers: {supported_providers}. "
            "Make sure the model string is formatted correctly as 'provider:model'."
        )

    # Initialize provider if not already initialized; the instance is stored on
    # the client so later calls for the same provider key reuse it.
    if provider_key not in self.client.providers:
        config = self.client.provider_configs.get(provider_key, {})
        self.client.providers[provider_key] = ProviderFactory.create_provider(
            provider_key, config
        )

    provider = self.client.providers.get(provider_key)
    if not provider:
        raise ValueError(f"Could not load provider for '{provider_key}'.")

    # Delegate the chat completion to the correct provider's implementation
    return provider.chat_completions_create(model_name, messages, **kwargs)