(embedding_model_name: str, logger=BaseLogger(), config={})
| 33 | |
| 34 | |
def load_embedding_model(embedding_model_name: str, logger=None, config=None):
    """Instantiate the embedding backend selected by ``embedding_model_name``.

    Args:
        embedding_model_name: One of ``"ollama"``, ``"openai"``, ``"aws"`` or
            ``"google-genai-embedding-001"``. Any other value (including typos)
            silently falls back to the local SentenceTransformer model
            ``all-MiniLM-L6-v2``.
        logger: Object with an ``info`` method, used to report which backend
            was chosen. A fresh ``BaseLogger`` is created when omitted.
        config: Settings mapping. It is only read for ``"ollama"``, where it
            must contain ``"ollama_base_url"``. Treated as empty when omitted.

    Returns:
        A ``(embeddings, dimension)`` tuple: the embeddings client and the
        vector size hard-coded for that backend.

    Raises:
        KeyError: if ``embedding_model_name == "ollama"`` and ``config`` has no
            ``"ollama_base_url"`` entry.
    """
    # Resolve defaults per call. Default expressions in the signature are
    # evaluated once at definition time, so a ``{}`` / ``BaseLogger()`` default
    # would be a single object shared by every call.
    if logger is None:
        logger = BaseLogger()
    if config is None:
        config = {}

    if embedding_model_name == "ollama":
        embeddings = OllamaEmbeddings(
            base_url=config["ollama_base_url"], model="llama2"
        )
        dimension = 4096
        logger.info("Embedding: Using Ollama")
    elif embedding_model_name == "openai":
        embeddings = OpenAIEmbeddings()
        dimension = 1536
        logger.info("Embedding: Using OpenAI")
    elif embedding_model_name == "aws":
        embeddings = BedrockEmbeddings()
        dimension = 1536
        logger.info("Embedding: Using AWS")
    elif embedding_model_name == "google-genai-embedding-001":
        embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
        dimension = 768
        logger.info("Embedding: Using Google Generative AI Embeddings")
    else:
        # Default / fallback: local model cached under /embedding_model.
        embeddings = HuggingFaceEmbeddings(
            model_name="all-MiniLM-L6-v2", cache_folder="/embedding_model"
        )
        dimension = 384
        logger.info("Embedding: Using SentenceTransformer")
    return embeddings, dimension
| 61 | |
| 62 | |
| 63 | def load_llm(llm_name: str, logger=BaseLogger(), config={}): |
no test coverage detected