Преглед изворни кода

allow_reset as constructor argument (#618)

Dev Khant пре 1 година
родитељ
комит
01fb216ff7

+ 5 - 3
embedchain/config/vectordbs/ChromaDbConfig.py

@@ -12,6 +12,7 @@ class ChromaDbConfig(BaseVectorDbConfig):
         dir: Optional[str] = None,
         host: Optional[str] = None,
         port: Optional[str] = None,
+        allow_reset=False,
         chroma_settings: Optional[dict] = None,
     ):
         """
@@ -25,11 +26,12 @@ class ChromaDbConfig(BaseVectorDbConfig):
         :type host: Optional[str], optional
         :param port: Database connection remote port. Use this if you run Embedchain as a client, defaults to None
         :type port: Optional[str], optional
+        :param allow_reset: Resets the database. defaults to False
+        :type allow_reset: bool
         :param chroma_settings: Chroma settings dict, defaults to None
         :type chroma_settings: Optional[dict], optional
         """
-        """
-        :param chroma_settings: Optional. Chroma settings for connection.
-        """
+
         self.chroma_settings = chroma_settings
+        self.allow_reset = allow_reset
         super().__init__(collection_name=collection_name, dir=dir, host=host, port=port)

+ 3 - 2
embedchain/vectordb/chroma.py

@@ -37,6 +37,7 @@ class ChromaDB(BaseVectorDB):
             self.config = ChromaDbConfig()
 
         self.settings = Settings()
+        self.settings.allow_reset = self.config.allow_reset
         if self.config.chroma_settings:
             for key, value in self.config.chroma_settings.items():
                 if hasattr(self.settings, key):
@@ -208,8 +209,8 @@ class ChromaDB(BaseVectorDB):
             self.client.reset()
         except ValueError:
             raise ValueError(
-                "For safety reasons, resetting is disabled."
-                'Please enable it by including `chromadb_settings={"allow_reset": True}` in your ChromaDbConfig'
+                "For safety reasons, resetting is disabled. "
+                "Please enable it by setting `allow_reset=True` in your ChromaDbConfig"
             ) from None
         # Recreate
         self._get_or_create_collection(self.config.collection_name)

+ 2 - 4
tests/vectordb/test_chroma_db.py

@@ -93,8 +93,7 @@ class TestChromaDbHostsLoglevel(unittest.TestCase):
 
 
 class TestChromaDbDuplicateHandling:
-    chroma_settings = {"allow_reset": True}
-    chroma_config = ChromaDbConfig(chroma_settings=chroma_settings)
+    chroma_config = ChromaDbConfig(allow_reset=True)
     app_config = AppConfig(collection_name=False, collect_metrics=False)
     app_with_settings = App(config=app_config, chromadb_config=chroma_config)
 
@@ -130,8 +129,7 @@ class TestChromaDbDuplicateHandling:
 
 
 class TestChromaDbCollection(unittest.TestCase):
-    chroma_settings = {"allow_reset": True}
-    chroma_config = ChromaDbConfig(chroma_settings=chroma_settings)
+    chroma_config = ChromaDbConfig(allow_reset=True)
     app_config = AppConfig(collection_name=False, collect_metrics=False)
     app_with_settings = App(config=app_config, chromadb_config=chroma_config)