Wrapper for a provider configuration and client
| 14 | logger = logging.getLogger(__name__) |
| 15 | |
| 16 | class Provider: |
| 17 | """Wrapper for a provider configuration and client""" |
| 18 | def __init__(self, config: Dict): |
| 19 | self.name = config['name'] |
| 20 | self.base_url = config['base_url'] |
| 21 | self.api_key = config['api_key'] |
| 22 | self.weight = config.get('weight', 1) |
| 23 | self.fallback_only = config.get('fallback_only', False) |
| 24 | self.model_map = config.get('model_map', {}) |
| 25 | self._client = None |
| 26 | self.is_healthy = True |
| 27 | self.last_error = None |
| 28 | self.latencies = [] # Track recent latencies |
| 29 | |
| 30 | # Per-provider concurrency control |
| 31 | self.max_concurrent = config.get('max_concurrent', None) # None means no limit |
| 32 | if self.max_concurrent is not None: |
| 33 | self._semaphore = threading.Semaphore(self.max_concurrent) |
| 34 | logger.info(f"Provider {self.name} limited to {self.max_concurrent} concurrent requests") |
| 35 | else: |
| 36 | self._semaphore = None |
| 37 | |
@property
def client(self):
    """Return the API client for this provider, creating it on first use.

    A base URL containing "azure" (case-insensitive) gets an ``AzureOpenAI``
    client; every other endpoint, including Google AI's
    ``generativelanguage.googleapis.com``, gets the standard
    OpenAI-compatible ``OpenAI`` client. The created client is stored in
    ``self._client`` and reused on later accesses.
    """
    if self._client:
        return self._client

    endpoint = self.base_url
    if 'azure' in endpoint.lower():
        # Azure OpenAI takes an endpoint plus a pinned API version
        self._client = AzureOpenAI(
            api_key=self.api_key,
            azure_endpoint=endpoint,
            api_version="2024-02-01",
            max_retries=0,  # Disable client retries - we handle them
        )
    else:
        # Google AI and all other providers share the same
        # OpenAI-compatible client construction
        self._client = OpenAI(
            api_key=self.api_key,
            base_url=endpoint,
            max_retries=0,  # Disable client retries - we handle them
        )
    return self._client
| 65 | |
def map_model(self, model: str) -> str:
    """Translate a requested model name into this provider's own name.

    Looks ``model`` up in ``self.model_map``; a name without an entry is
    returned unchanged.
    """
    aliases = self.model_map
    provider_model = aliases.get(model, model)
    return provider_model
| 69 | |
| 70 | def track_latency(self, latency: float): |
| 71 | """Track request latency""" |
| 72 | self.latencies.append(latency) |
| 73 | if len(self.latencies) > 10: |