| 29 | |
| 30 | |
| 31 | class EmbeddingClient: |
| 32 | def __init__(self, model_name: str = "text-embedding-3-small"): |
| 33 | """ |
| 34 | Initialize the EmbeddingClient. |
| 35 | |
| 36 | Args: |
| 37 | model (str): The OpenAI embedding model name to use. |
| 38 | """ |
| 39 | self.client, self.model = self._get_client_model(model_name) |
| 40 | |
def _get_client_model(self, model_name: str) -> tuple[openai.OpenAI, str]:
    """
    Build the API client for an embedding model and resolve the model name.

    Args:
        model_name (str): A name listed in ``OPENAI_EMBEDDING_MODELS`` or
            ``AZURE_EMBEDDING_MODELS``.

    Returns:
        tuple: ``(client, model_to_use)``. ``client`` is an ``openai.OpenAI``
            instance for OpenAI models or an ``openai.AzureOpenAI`` instance
            for Azure models. ``model_to_use`` is the name sent with
            requests; for Azure models a leading ``azure-`` is removed.

    Raises:
        ValueError: If ``model_name`` is in neither model collection.
    """
    if model_name in OPENAI_EMBEDDING_MODELS:
        # Use OPENAI_EMBEDDING_API_KEY if set, otherwise fall back to OPENAI_API_KEY
        # This allows users to use OpenRouter for LLMs while using OpenAI for embeddings
        embedding_api_key = os.getenv("OPENAI_EMBEDDING_API_KEY") or os.getenv("OPENAI_API_KEY")
        client = openai.OpenAI(api_key=embedding_api_key)
        model_to_use = model_name
    elif model_name in AZURE_EMBEDDING_MODELS:
        # Strip only a *leading* "azure-". Splitting on every occurrence and
        # keeping the last piece would truncate a deployment name that
        # happens to contain "azure-" again further in.
        azure_prefix = "azure-"
        if model_name.startswith(azure_prefix):
            model_to_use = model_name[len(azure_prefix):]
        else:
            model_to_use = model_name
        client = openai.AzureOpenAI(
            api_key=os.getenv("AZURE_OPENAI_API_KEY"),
            api_version=os.getenv("AZURE_API_VERSION"),
            azure_endpoint=os.getenv("AZURE_API_ENDPOINT"),
        )
    else:
        raise ValueError(f"Invalid embedding model: {model_name}")

    return client, model_to_use
| 60 | |
| 61 | def get_embedding(self, code: Union[str, List[str]]) -> Union[List[float], List[List[float]]]: |
| 62 | """ |
| 63 | Computes the text embedding for a code string. |
| 64 | |
| 65 | Args: |
| 66 | code (str, list[str]): The code as a string or list |
| 67 | of strings. |
| 68 | |
| 69 | Returns: |
| 70 | list: Embedding vector for the code or None if an error |
| 71 | occurs. |
| 72 | """ |
| 73 | if isinstance(code, str): |
| 74 | code = [code] |
| 75 | single_code = True |
| 76 | else: |
| 77 | single_code = False |
| 78 | try: |
| 79 | response = self.client.embeddings.create( |
| 80 | model=self.model, input=code, encoding_format="float" |
| 81 | ) |
| 82 | # Extract embedding from response |
| 83 | if single_code: |
| 84 | return response.data[0].embedding |
| 85 | else: |
| 86 | return [d.embedding for d in response.data] |
| 87 | except Exception as e: |
| 88 | logger.info(f"Error getting embedding: {e}") |