Преглед изворни кода

fix: reset destroys app (#319)

Co-authored-by: cachho <admin@ch-webdev.com>
Jonas пре 2 година
родитељ
комит
28e06be26f
3 измењених фајлова са 34 додато и 5 уклоњено
  1. 11 3
      embedchain/embedchain.py
  2. 22 0
      tests/embedchain/test_embedchain.py
  3. 1 2
      tests/vectordb/test_chroma_db.py

+ 11 - 3
embedchain/embedchain.py

@@ -372,13 +372,21 @@ class EmbedChain:
     def reset(self):
         """
         Resets the database. Deletes all embeddings irreversibly.
-        `App` has to be reinitialized after using this method.
+        `App` does not have to be reinitialized after using this method.
         """
         # Send anonymous telemetry
         thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("reset",))
         thread_telemetry.start()
-
+      
+        collection_name = self.collection.name
         self.db.reset()
+        self.collection = self.config.db._get_or_create_collection(collection_name)
+        # TODO: Automatically recreating a collection with the same name is probably not the best way to handle a reset.
+        # A downside of this implementation is that, if you have two instances,
+        # the other instance will not get the updated `self.collection` attribute.
+        # A better way would be to create the collection lazily when it is used again after being reset.
+        # That means checking whether the collection exists in the db-consuming methods, and creating it if it doesn't.
+        # That's an extra step for every use, just to satisfy a niche use case in a niche method. For now, this will do.
 
     @retry(stop=stop_after_attempt(3), wait=wait_fixed(1))
     def _send_telemetry_event(self, method: str, extra_metadata: Optional[dict] = None):
@@ -397,4 +405,4 @@ class EmbedChain:
                 metadata.update(extra_metadata)
 
             response = requests.post(url, json={"metadata": metadata})
-            response.raise_for_status()
+            response.raise_for_status()

+ 22 - 0
tests/embedchain/test_embedchain.py

@@ -37,3 +37,25 @@ class TestChromaDbHostsLoglevel(unittest.TestCase):
         app.chat("What text did I give you?")
 
         self.assertEqual(mock_ec_get_llm_model_answer.call_args[1]["documents"], [knowledge])
+
+    def test_add_after_reset(self):
+        """
+        Test if the `App` instance is correctly reconstructed after a reset.
+        """
+        app = App()
+        app.reset()
+
+        # Make sure the client is still healthy
+        app.db.client.heartbeat()
+        # Make sure the collection exists, and can be added to
+        app.collection.add(
+            embeddings=[[1.1, 2.3, 3.2], [4.5, 6.9, 4.4], [1.1, 2.3, 3.2]],
+            metadatas=[
+                {"chapter": "3", "verse": "16"},
+                {"chapter": "3", "verse": "5"},
+                {"chapter": "29", "verse": "11"},
+            ],
+            ids=["id1", "id2", "id3"],
+        )
+
+        app.reset()

+ 1 - 2
tests/vectordb/test_chroma_db.py

@@ -245,8 +245,7 @@ class TestChromaDbCollection(unittest.TestCase):
         # Resetting the first one should reset them all.
         app1.reset()
 
-        # Reinstantiate them
-        app1 = App(AppConfig(collection_name="one_collection", id="new_app_id_1", collect_metrics=False))
+        # Reinstantiate app2-4, app1 doesn't have to be reinstantiated (PR #319)
         app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2", collect_metrics=False))
         app3 = App(AppConfig(collection_name="three_collection", id="new_app_id_3", collect_metrics=False))
         app4 = App(AppConfig(collection_name="four_collection", id="new_app_id_3", collect_metrics=False))