Get text embeddings using OpenAI's text-embedding-3-small model. Args: texts (Union[str, List[str]]): A single text string or a list of text strings to embed. max_workers (int): The maximum number of workers for parallel processing.
(
self,
texts: Union[str, List[str]],
max_workers: int = 5,
)
| 130 | return text, embedding, token_usage |
| 131 | |
def _get_text_embeddings(
    self,
    texts: Union[str, List[str]],
    max_workers: int = 5,
) -> np.ndarray:
    """
    Get text embeddings using OpenAI's text-embedding-3-small model.

    Each text is embedded by ``self._get_single_text_embedding``. For a list
    input, those calls are fanned out over a thread pool. The token counts
    reported by successful calls are added to ``self.total_token_usage``.

    Args:
        texts (Union[str, List[str]]): A single text string or a list of text
            strings to embed.
        max_workers (int): The maximum number of worker threads used when
            ``texts`` is a list.

    Returns:
        np.ndarray: For a single string, the helper's embedding converted with
        ``np.array``. For a list, ``np.array`` over the embeddings of the texts
        that succeeded, in input order. This is 2D when all embeddings have the
        same length. A text whose embedding call raised is reported on stdout
        and omitted, so the result can have fewer rows than ``texts``.
    """
    if isinstance(texts, str):
        _, embedding, tokens = self._get_single_text_embedding(texts)
        self.total_token_usage += tokens
        return np.array(embedding)

    # (input position, embedding) for each text that was embedded successfully.
    positioned = []
    total_tokens = 0

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Remember each future's input position (and text, for error reports),
        # so ordering never depends on the text the helper echoes back.
        futures = {
            executor.submit(self._get_single_text_embedding, text): (pos, text)
            for pos, text in enumerate(texts)
        }

        for future in as_completed(futures):
            pos, text = futures[future]
            try:
                _, embedding, tokens = future.result()
            except Exception as e:
                # Best-effort: report the failure and skip this text.
                print(f"An error occurred for text: {text}")
                print(e)
                continue
            positioned.append((pos, embedding))
            total_tokens += tokens

    # Results arrive in completion order; restore input order by position.
    positioned.sort(key=lambda item: item[0])
    self.total_token_usage += total_tokens

    return np.array([embedding for _, embedding in positioned])
no test coverage detected