Przeglądaj źródła

Merge pull request #22 from DumoeDss/feature_add_other_vectordb

[feat] Refactor VectorDB class hierarchy for flexibility
Taranjeet Singh 2 lat temu
rodzic
commit
21527e417a

+ 10 - 35
embedchain/embedchain.py

@@ -1,8 +1,6 @@
-import chromadb
 import openai
 import os
 
-from chromadb.utils import embedding_functions
 from dotenv import load_dotenv
 from langchain.docstore.document import Document
 from langchain.embeddings.openai import OpenAIEmbeddings
@@ -13,6 +11,7 @@ from embedchain.loaders.web_page import WebPageLoader
 from embedchain.chunkers.youtube_video import YoutubeVideoChunker
 from embedchain.chunkers.pdf_file import PdfFileChunker
 from embedchain.chunkers.web_page import WebPageChunker
+from embedchain.vectordb.chroma_db import ChromaDB
 
 load_dotenv()
 
@@ -21,20 +20,19 @@ embeddings = OpenAIEmbeddings()
 ABS_PATH = os.getcwd()
 DB_DIR = os.path.join(ABS_PATH, "db")
 
-openai_ef = embedding_functions.OpenAIEmbeddingFunction(
-    api_key=os.getenv("OPENAI_API_KEY"),
-    model_name="text-embedding-ada-002"
-)
-
 
 class EmbedChain:
-    def __init__(self):
+    def __init__(self, db=None):
         """
-        Initializes the EmbedChain instance, sets up a ChromaDB client and
-        creates a ChromaDB collection.
+         Initializes the EmbedChain instance, sets up a vector DB client and
+        creates a collection.
+
+        :param db: The instance of the VectorDB subclass.
         """
-        self.chromadb_client = self._get_or_create_db()
-        self.collection = self._get_or_create_collection()
+        if db is None:
+            db = ChromaDB()
+        self.db_client = db.client
+        self.collection = db.collection
         self.user_asks = []
 
     def _get_loader(self, data_type):
@@ -87,29 +85,6 @@ class EmbedChain:
         self.user_asks.append([data_type, url])
         self.load_and_embed(loader, chunker, url)
 
-    def _get_or_create_db(self):
-        """
-        Returns a ChromaDB client, creates a new one if needed.
-
-        :return: The ChromaDB client.
-        """
-        client_settings = chromadb.config.Settings(
-            chroma_db_impl="duckdb+parquet",
-            persist_directory=DB_DIR,
-            anonymized_telemetry=False
-        )
-        return chromadb.Client(client_settings)
-
-    def _get_or_create_collection(self):
-        """
-        Returns a ChromaDB collection, creates a new one if needed.
-
-        :return: The ChromaDB collection.
-        """
-        return self.chromadb_client.get_or_create_collection(
-            'embedchain_store', embedding_function=openai_ef,
-        )
-
     def load_and_embed(self, loader, chunker, url):
         """
         Loads the data from the given URL, chunks it, and adds it to the database.

+ 10 - 0
embedchain/vectordb/base_vector_db.py

@@ -0,0 +1,10 @@
+class BaseVectorDB:
+    def __init__(self):
+        self.client = self._get_or_create_db()
+        self.collection = self._get_or_create_collection()
+
+    def _get_or_create_db(self):
+        raise NotImplementedError
+
+    def _get_or_create_collection(self):
+        raise NotImplementedError

+ 30 - 0
embedchain/vectordb/chroma_db.py

@@ -0,0 +1,30 @@
+import chromadb
+import os
+
+from chromadb.utils import embedding_functions
+
+from embedchain.vectordb.base_vector_db import BaseVectorDB
+
+openai_ef = embedding_functions.OpenAIEmbeddingFunction(
+    api_key=os.getenv("OPENAI_API_KEY"),
+    model_name="text-embedding-ada-002"
+)
+
+class ChromaDB(BaseVectorDB):
+    def __init__(self, db_dir=None):
+        if db_dir is None:
+            db_dir = "db"
+        self.client_settings = chromadb.config.Settings(
+            chroma_db_impl="duckdb+parquet",
+            persist_directory=db_dir,
+            anonymized_telemetry=False
+        )
+        super().__init__()
+
+    def _get_or_create_db(self):
+        return chromadb.Client(self.client_settings)
+
+    def _get_or_create_collection(self):
+        return self.client.get_or_create_collection(
+            'embedchain_store', embedding_function=openai_ef,
+        )