Manages vector embeddings for graph nodes in SQLite.
| 851 | |
| 852 | |
| 853 | class EmbeddingStore: |
| 854 | """Manages vector embeddings for graph nodes in SQLite.""" |
| 855 | |
def __init__(
    self,
    db_path: str | Path,
    provider: str | None = None,
    model: str | None = None,
) -> None:
    """Open the embeddings database and bring its schema up to date.

    Args:
        db_path: Path of the SQLite database file to connect to.
        provider: Forwarded to ``get_provider`` as its first argument.
        model: Forwarded to ``get_provider`` as the ``model`` keyword.

    Raises:
        sqlite3.Error: If connecting, applying the schema, or running the
            migration fails. When the failure happens after the connection
            was opened, the connection is closed before the error
            propagates so the handle is not leaked.
    """
    self.provider = get_provider(provider, model=model)
    # True only when get_provider returned something other than None.
    self.available = self.provider is not None
    self.db_path = Path(db_path)
    # isolation_level=None puts the sqlite3 module in autocommit mode;
    # check_same_thread=False disables the module's same-thread check.
    self._conn = sqlite3.connect(
        str(self.db_path), timeout=30, check_same_thread=False,
        isolation_level=None,
    )
    try:
        self._conn.row_factory = sqlite3.Row
        self._conn.executescript(_EMBEDDINGS_SCHEMA)

        # Migration for existing DBs missing the provider column: probing
        # the column raises OperationalError when it does not exist yet.
        try:
            self._conn.execute("SELECT provider FROM embeddings LIMIT 1")
        except sqlite3.OperationalError:
            self._conn.execute(
                "ALTER TABLE embeddings ADD COLUMN provider "
                "TEXT NOT NULL DEFAULT 'unknown'"
            )

        # Does nothing when no transaction is open; kept as a safeguard.
        self._conn.commit()
    except BaseException:
        # The caller never receives the instance if setup fails, so it
        # cannot call close() itself; release the handle here instead.
        self._conn.close()
        raise
| 882 | |
| 883 | def __enter__(self) -> "EmbeddingStore": |
| 884 | return self |
| 885 | |
| 886 | def __exit__(self, exc_type, exc_val, exc_tb) -> None: # type: ignore[no-untyped-def] |
| 887 | self.close() |
| 888 | |
| 889 | def close(self) -> None: |
| 890 | self._conn.close() |
| 891 | |
| 892 | def embed_nodes(self, nodes: list[GraphNode], batch_size: int = 64) -> int: |
| 893 | """Compute and store embeddings for a list of nodes.""" |
| 894 | if not self.provider: |
| 895 | return 0 |
| 896 | |
| 897 | # Filter to nodes that need embedding |
| 898 | to_embed: list[tuple[GraphNode, str, str]] = [] |
| 899 | provider_name = self.provider.name |
| 900 | |
| 901 | for node in nodes: |
| 902 | if node.kind == "File": |
| 903 | continue |
| 904 | text = _node_to_text(node) |
| 905 | text_hash = hashlib.sha256(text.encode()).hexdigest() |
| 906 | |
| 907 | existing = self._conn.execute( |
| 908 | "SELECT text_hash, provider FROM embeddings WHERE qualified_name = ?", |
| 909 | (node.qualified_name,), |
| 910 | ).fetchone() |
no outgoing calls