Forráskód Böngészése

Upgrade chromadb to version 0.4.8 and expose its settings configuration. (#517)

wangJm 2 éve
szülő
commit
eecdbc5e06

+ 5 - 2
embedchain/config/apps/BaseAppConfig.py

@@ -24,6 +24,7 @@ class BaseAppConfig(BaseConfig, JSONSerializable):
         db_type: VectorDatabases = None,
         vector_dim: VectorDimensions = None,
         es_config: ElasticsearchDBConfig = None,
+        chroma_settings: dict = {},
     ):
         """
         :param log_level: Optional. (String) Debug level
@@ -38,6 +39,7 @@ class BaseAppConfig(BaseConfig, JSONSerializable):
         :param db_type: Optional. type of Vector database to use
         :param vector_dim: Vector dimension generated by embedding fn
         :param es_config: Optional. elasticsearch database config to be used for connection
+        :param chroma_settings: Optional. Chroma settings for connection.
         """
         self._setup_logging(log_level)
         self.collection_name = collection_name if collection_name else "embedchain_store"
@@ -50,13 +52,14 @@ class BaseAppConfig(BaseConfig, JSONSerializable):
             vector_dim=vector_dim,
             collection_name=self.collection_name,
             es_config=es_config,
+            chroma_settings=chroma_settings,
         )
         self.id = id
         self.collect_metrics = True if (collect_metrics is True or collect_metrics is None) else False
         return
 
     @staticmethod
-    def get_db(db, embedding_fn, host, port, db_type, vector_dim, collection_name, es_config):
+    def get_db(db, embedding_fn, host, port, db_type, vector_dim, collection_name, es_config, chroma_settings):
         """
         Get db based on db_type, db with default database (`ChromaDb`)
         :param db: Optional. (Vector) database to use for embeddings.
@@ -85,7 +88,7 @@ class BaseAppConfig(BaseConfig, JSONSerializable):
 
         from embedchain.vectordb.chroma_db import ChromaDB
 
-        return ChromaDB(embedding_fn=embedding_fn, host=host, port=port)
+        return ChromaDB(embedding_fn=embedding_fn, host=host, port=port, chroma_settings=chroma_settings)
 
     def _setup_logging(self, debug_level):
         level = logging.WARNING  # Default level

+ 3 - 0
embedchain/config/apps/CustomAppConfig.py

@@ -35,6 +35,7 @@ class CustomAppConfig(BaseAppConfig):
         collect_metrics: Optional[bool] = None,
         db_type: VectorDatabases = None,
         es_config: ElasticsearchDBConfig = None,
+        chroma_settings: dict = {},
     ):
         """
         :param log_level: Optional. (String) Debug level
@@ -51,6 +52,7 @@ class CustomAppConfig(BaseAppConfig):
         :param collect_metrics: Defaults to True. Send anonymous telemetry to improve embedchain.
         :param db_type: Optional. type of Vector database to use.
         :param es_config: Optional. elasticsearch database config to be used for connection
+        :param chroma_settings: Optional. Chroma settings for connection.
         """
         if provider:
             self.provider = provider
@@ -73,6 +75,7 @@ class CustomAppConfig(BaseAppConfig):
             db_type=db_type,
             vector_dim=CustomAppConfig.get_vector_dimension(embedding_function=embedding_fn),
             es_config=es_config,
+            chroma_settings=chroma_settings,
         )
 
     @staticmethod

+ 15 - 7
embedchain/vectordb/chroma_db.py

@@ -22,23 +22,31 @@ from embedchain.vectordb.base_vector_db import BaseVectorDB
 class ChromaDB(BaseVectorDB):
     """Vector database using ChromaDB."""
 
-    def __init__(self, db_dir=None, embedding_fn=None, host=None, port=None):
+    def __init__(self, db_dir=None, embedding_fn=None, host=None, port=None, chroma_settings={}):
         self.embedding_fn = embedding_fn
 
         if not hasattr(embedding_fn, "__call__"):
             raise ValueError("Embedding function is not a function")
 
+        self.settings = Settings()
+        for key, value in chroma_settings.items():
+            if hasattr(self.settings, key):
+                setattr(self.settings, key, value)
+
         if host and port:
             logging.info(f"Connecting to ChromaDB server: {host}:{port}")
-            self.client = chromadb.HttpClient(host=host, port=port)
+            self.settings.chroma_server_host = host
+            self.settings.chroma_server_http_port = port
+            self.settings.chroma_api_impl = "chromadb.api.fastapi.FastAPI"
+
         else:
             if db_dir is None:
                 db_dir = "db"
-            self.settings = Settings(anonymized_telemetry=False, allow_reset=True)
-            self.client = chromadb.PersistentClient(
-                path=db_dir,
-                settings=self.settings,
-            )
+
+            self.settings.persist_directory = db_dir
+            self.settings.is_persistent = True
+
+        self.client = chromadb.Client(self.settings)
         super().__init__()
 
     def _get_or_create_db(self):

+ 1 - 1
pyproject.toml

@@ -86,7 +86,7 @@ langchain = "^0.0.279"
 requests = "^2.31.0"
 openai = "^0.27.5"
 tiktoken = "^0.4.0"
-chromadb ="^0.4.2"
+chromadb ="^0.4.8"
 youtube-transcript-api = "^0.6.1"
 beautifulsoup4 = "^4.12.2"
 pypdf = "^3.11.0"

+ 7 - 2
tests/embedchain/test_embedchain.py

@@ -3,7 +3,8 @@ import unittest
 from unittest.mock import patch
 
 from embedchain import App
-from embedchain.config import AppConfig
+from embedchain.config import AppConfig, CustomAppConfig
+from embedchain.models import EmbeddingFunctions, Providers
 
 
 class TestChromaDbHostsLoglevel(unittest.TestCase):
@@ -42,7 +43,11 @@ class TestChromaDbHostsLoglevel(unittest.TestCase):
         """
         Test if the `App` instance is correctly reconstructed after a reset.
         """
-        app = App()
+        app = App(
+            CustomAppConfig(
+                provider=Providers.OPENAI, embedding_fn=EmbeddingFunctions.OPENAI, chroma_settings={"allow_reset": True}
+            )
+        )
         app.reset()
 
         # Make sure the client is still healthy

+ 49 - 9
tests/vectordb/test_chroma_db.py

@@ -4,7 +4,8 @@ import unittest
 from unittest.mock import patch
 
 from embedchain import App
-from embedchain.config import AppConfig
+from embedchain.config import AppConfig, CustomAppConfig
+from embedchain.models import EmbeddingFunctions, Providers
 from embedchain.vectordb.chroma_db import ChromaDB
 
 
@@ -21,6 +22,24 @@ class TestChromaDbHosts(unittest.TestCase):
         self.assertEqual(settings.chroma_server_host, host)
         self.assertEqual(settings.chroma_server_http_port, port)
 
+    def test_init_with_basic_auth(self):
+        host = "test-host"
+        port = "1234"
+
+        chroma_auth_settings = {
+            "chroma_client_auth_provider": "chromadb.auth.basic.BasicAuthClientProvider",
+            "chroma_client_auth_credentials": "admin:admin",
+        }
+
+        db = ChromaDB(host=host, port=port, embedding_fn=len, chroma_settings=chroma_auth_settings)
+        settings = db.client.get_settings()
+        self.assertEqual(settings.chroma_server_host, host)
+        self.assertEqual(settings.chroma_server_http_port, port)
+        self.assertEqual(settings.chroma_client_auth_provider, chroma_auth_settings["chroma_client_auth_provider"])
+        self.assertEqual(
+            settings.chroma_client_auth_credentials, chroma_auth_settings["chroma_client_auth_credentials"]
+        )
+
 
 # Review this test
 class TestChromaDbHostsInit(unittest.TestCase):
@@ -68,12 +87,18 @@ class TestChromaDbHostsLoglevel(unittest.TestCase):
 
 
 class TestChromaDbDuplicateHandling:
+    app_with_settings = App(
+        CustomAppConfig(
+            provider=Providers.OPENAI, embedding_fn=EmbeddingFunctions.OPENAI, chroma_settings={"allow_reset": True}
+        )
+    )
+
     def test_duplicates_throw_warning(self, caplog):
         """
         Test that adding duplicates logs a warning.
         """
         # Start with a clean app
-        App().reset()
+        self.app_with_settings.reset()
 
         app = App(config=AppConfig(collect_metrics=False))
         app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
@@ -88,7 +113,7 @@ class TestChromaDbDuplicateHandling:
         # NOTE: Not part of the TestChromaDbCollection because `unittest.TestCase` doesn't have caplog.
 
         # Start with a clean app
-        App().reset()
+        self.app_with_settings.reset()
 
         app = App(config=AppConfig(collect_metrics=False))
         app.set_collection("test_collection_1")
@@ -100,6 +125,12 @@ class TestChromaDbDuplicateHandling:
 
 
 class TestChromaDbCollection(unittest.TestCase):
+    app_with_settings = App(
+        CustomAppConfig(
+            provider=Providers.OPENAI, embedding_fn=EmbeddingFunctions.OPENAI, chroma_settings={"allow_reset": True}
+        )
+    )
+
     def test_init_with_default_collection(self):
         """
         Test if the `App` instance is initialized with the correct default collection name.
@@ -131,7 +162,7 @@ class TestChromaDbCollection(unittest.TestCase):
         Test that changes to one collection do not affect the other collection
         """
         # Start with a clean app
-        App().reset()
+        self.app_with_settings.reset()
 
         app = App(config=AppConfig(collect_metrics=False))
         app.set_collection("test_collection_1")
@@ -157,7 +188,7 @@ class TestChromaDbCollection(unittest.TestCase):
         Test that a collection can be picked up later.
         """
         # Start with a clean app
-        App().reset()
+        self.app_with_settings.reset()
 
         app = App(config=AppConfig(collect_metrics=False))
         app.set_collection("test_collection_1")
@@ -175,7 +206,7 @@ class TestChromaDbCollection(unittest.TestCase):
         the other app.
         """
         # Start clean
-        App().reset()
+        self.app_with_settings.reset()
 
         # Create two apps
         app1 = App(AppConfig(collection_name="test_collection_1", collect_metrics=False))
@@ -201,7 +232,7 @@ class TestChromaDbCollection(unittest.TestCase):
         Different ids should still share collections.
         """
         # Start clean
-        App().reset()
+        self.app_with_settings.reset()
 
         # Create two apps
         app1 = App(AppConfig(collection_name="one_collection", id="new_app_id_1", collect_metrics=False))
@@ -220,11 +251,20 @@ class TestChromaDbCollection(unittest.TestCase):
         Resetting should hit all collections and ids.
         """
         # Start clean
-        App().reset()
+        self.app_with_settings.reset()
 
         # Create four apps.
         # app1, which we are about to reset, shares a collection with one, and an id with the other, none with the last.
-        app1 = App(AppConfig(collection_name="one_collection", id="new_app_id_1", collect_metrics=False))
+        app1 = App(
+            CustomAppConfig(
+                collection_name="one_collection",
+                id="new_app_id_1",
+                collect_metrics=False,
+                provider=Providers.OPENAI,
+                embedding_fn=EmbeddingFunctions.OPENAI,
+                chroma_settings={"allow_reset": True},
+            )
+        )
         app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2", collect_metrics=False))
         app3 = App(AppConfig(collection_name="three_collection", id="new_app_id_1", collect_metrics=False))
         app4 = App(AppConfig(collection_name="four_collection", id="new_app_id_4", collect_metrics=False))