Quellcode durchsuchen

[feat] Refactor VectorDB class hierarchy for flexibility

Sayo vor 2 Jahren
Ursprung
Commit
85a6a0c161
3 geänderte Dateien mit 43 neuen und 35 gelöschten Zeilen
  1. 7 35
      embedchain/embedchain.py
  2. 10 0
      embedchain/vectordb/base_vector_db.py
  3. 26 0
      embedchain/vectordb/chroma_db.py

+ 7 - 35
embedchain/embedchain.py

@@ -1,8 +1,6 @@
-import chromadb
 import openai
 import os
 
-from chromadb.utils import embedding_functions
 from dotenv import load_dotenv
 from langchain.docstore.document import Document
 from langchain.embeddings.openai import OpenAIEmbeddings
@@ -21,20 +19,17 @@ embeddings = OpenAIEmbeddings()
 ABS_PATH = os.getcwd()
 DB_DIR = os.path.join(ABS_PATH, "db")
 
-openai_ef = embedding_functions.OpenAIEmbeddingFunction(
-    api_key=os.getenv("OPENAI_API_KEY"),
-    model_name="text-embedding-ada-002"
-)
-
 
 class EmbedChain:
-    def __init__(self):
+    def __init__(self, db):
         """
-        Initializes the EmbedChain instance, sets up a ChromaDB client and
-        creates a ChromaDB collection.
+         Initializes the EmbedChain instance, sets up a vector DB client and
+        creates a collection.
+
+        :param db: The instance of the VectorDB subclass.
         """
-        self.chromadb_client = self._get_or_create_db()
-        self.collection = self._get_or_create_collection()
+        self.db_client = db.client
+        self.collection = db.collection
         self.user_asks = []
 
     def _get_loader(self, data_type):
@@ -87,29 +82,6 @@ class EmbedChain:
         self.user_asks.append([data_type, url])
         self.load_and_embed(loader, chunker, url)
 
-    def _get_or_create_db(self):
-        """
-        Returns a ChromaDB client, creates a new one if needed.
-
-        :return: The ChromaDB client.
-        """
-        client_settings = chromadb.config.Settings(
-            chroma_db_impl="duckdb+parquet",
-            persist_directory=DB_DIR,
-            anonymized_telemetry=False
-        )
-        return chromadb.Client(client_settings)
-
-    def _get_or_create_collection(self):
-        """
-        Returns a ChromaDB collection, creates a new one if needed.
-
-        :return: The ChromaDB collection.
-        """
-        return self.chromadb_client.get_or_create_collection(
-            'embedchain_store', embedding_function=openai_ef,
-        )
-
     def load_and_embed(self, loader, chunker, url):
         """
         Loads the data from the given URL, chunks it, and adds it to the database.

+ 10 - 0
embedchain/vectordb/base_vector_db.py

@@ -0,0 +1,10 @@
+class BaseVectorDB:
+    def __init__(self):
+        self.client = self._get_or_create_db()
+        self.collection = self._get_or_create_collection()
+
+    def _get_or_create_db(self):
+        raise NotImplementedError
+
+    def _get_or_create_collection(self):
+        raise NotImplementedError

+ 26 - 0
embedchain/vectordb/chroma_db.py

@@ -0,0 +1,26 @@
+import os
+import chromadb
+from base_vector_db import BaseVectorDB
+from chromadb.utils import embedding_functions
+
+openai_ef = embedding_functions.OpenAIEmbeddingFunction(
+    api_key=os.getenv("OPENAI_API_KEY"),
+    model_name="text-embedding-ada-002"
+)
+
+class ChromaDB(BaseVectorDB):
+    def __init__(self, db_dir):
+        self.client_settings = chromadb.config.Settings(
+            chroma_db_impl="duckdb+parquet",
+            persist_directory=db_dir,
+            anonymized_telemetry=False
+        )
+        super().__init__()
+
+    def _get_or_create_db(self):
+        return chromadb.Client(self.client_settings)
+
+    def _get_or_create_collection(self):
+        return self.client.get_or_create_collection(
+            'embedchain_store', embedding_function=openai_ef,
+        )