| 61 | |
| 62 | |
| 63 | class PostgresDB(VectorStore): |
def __init__(self, config: PostgresDBConfig | None = None):
    """Set up a Postgres-backed vector store.

    Creates the SQLAlchemy engine from the config, calls
    ``_create_vector_extension`` on it, builds a session factory bound to
    the engine, and sets up the table via ``_setup_table``.

    Args:
        config: Connection and table settings. If omitted, a fresh
            ``PostgresDBConfig()`` is created for this instance. The previous
            default was a single ``PostgresDBConfig()`` built once when the
            function was defined and shared by every instance. That was
            unsafe because ``_create_engine`` mutates ``self.config`` (it
            sets ``cloud = False`` on the Docker path).

    Raises:
        LangroidImportError: if ``has_postgres`` is false (reported as
            missing ``pgvector``), or if ``sqlalchemy.orm`` cannot be
            imported.
    """
    # Build a fresh default per instance rather than sharing one
    # definition-time object across all callers.
    if config is None:
        config = PostgresDBConfig()
    super().__init__(config)
    if not has_postgres:
        raise LangroidImportError("pgvector", "postgres")
    try:
        # Imported lazily so the module loads without the optional extra.
        from sqlalchemy.orm import sessionmaker
    except ImportError as e:
        raise LangroidImportError("sqlalchemy", "postgres") from e

    self.config: PostgresDBConfig = config
    self.engine = self._create_engine()
    PostgresDB._create_vector_extension(self.engine)
    # Sessions are created with explicit commit/flush control.
    self.SessionLocal = sessionmaker(
        autocommit=False, autoflush=False, bind=self.engine
    )
    self.metadata = MetaData()
    self._setup_table()
| 81 | |
| 82 | def _create_engine(self) -> Engine: |
| 83 | """Creates a SQLAlchemy engine based on the configuration.""" |
| 84 | |
| 85 | connection_string: str | None = None # Ensure variable is always defined |
| 86 | |
| 87 | if self.config.cloud: |
| 88 | connection_string = os.getenv("POSTGRES_CONNECTION_STRING") |
| 89 | |
| 90 | if connection_string and connection_string.startswith("postgres://"): |
| 91 | connection_string = connection_string.replace( |
| 92 | "postgres://", "postgresql+psycopg2://", 1 |
| 93 | ) |
| 94 | elif not connection_string: |
| 95 | raise ValueError("Provide the POSTGRES_CONNECTION_STRING.") |
| 96 | |
| 97 | elif self.config.docker: |
| 98 | username = os.getenv("POSTGRES_USER", "postgres") |
| 99 | password = os.getenv("POSTGRES_PASSWORD", "postgres") |
| 100 | database = os.getenv("POSTGRES_DB", "langroid") |
| 101 | |
| 102 | if not (username and password and database): |
| 103 | raise ValueError( |
| 104 | "Provide POSTGRES_USER, POSTGRES_PASSWORD, " "POSTGRES_DB. " |
| 105 | ) |
| 106 | |
| 107 | connection_string = ( |
| 108 | f"postgresql+psycopg2://{username}:{password}@" |
| 109 | f"{self.config.host}:{self.config.port}/{database}" |
| 110 | ) |
| 111 | self.config.cloud = False # Ensures cloud is disabled if using Docker |
| 112 | |
| 113 | else: |
| 114 | raise ValueError( |
| 115 | "Provide either Docker or Cloud config to connect to the database." |
| 116 | ) |
| 117 | |
| 118 | return create_engine( |
| 119 | connection_string, |
| 120 | pool_size=self.config.pool_size, |