Explorar o código

feat: collection name everywhere (#310)

Co-authored-by: cachho <admin@ch-webdev.com>
Jonas hai 2 anos
pai
achega
eeac84e2d9

+ 4 - 0
docs/advanced/configuration.mdx

@@ -20,6 +20,10 @@ from chromadb.utils import embedding_functions
 config = AppConfig(log_level="DEBUG")
 naval_chat_bot = App(config)
 
+# Example: specify a custom collection name
+config = AppConfig(collection_name="naval_chat_bot")
+naval_chat_bot = App(config)
+
 # Example: define your own chunker config for `youtube_video`
 chunker_config = ChunkerConfig(chunk_size=1000, chunk_overlap=100, length_function=len)
 naval_chat_bot.add("youtube_video", "https://www.youtube.com/watch?v=3qHkcs3kG44", AddConfig(chunker=chunker_config))

+ 5 - 4
docs/advanced/query_configuration.mdx

@@ -4,11 +4,12 @@ title: '🔍 Query configurations'
 
 ## AppConfig
 
-| option      | description           | type                            | default                |
-|-------------|-----------------------|---------------------------------|------------------------|
-| log_level   | log level             | string                          | WARNING                |
+| option    | description           | type                            | default                |
+|-----------|-----------------------|---------------------------------|------------------------|
+| log_level | log level             | string                          | WARNING                |
 | embedding_fn| embedding function    | chromadb.utils.embedding_functions | \{text-embedding-ada-002\} |
-| db          | vector database (experimental) | BaseVectorDB               | ChromaDB               |
+| db        | vector database (experimental) | BaseVectorDB               | ChromaDB               |
+| collection_name | initial collection name for the database | string             | embedchain_store |
 
 
 ## AddConfig

+ 8 - 2
embedchain/config/apps/AppConfig.py

@@ -16,16 +16,22 @@ class AppConfig(BaseAppConfig):
     Config to initialize an embedchain custom `App` instance, with extra config options.
     """
 
-    def __init__(self, log_level=None, host=None, port=None, id=None):
+    def __init__(self, log_level=None, host=None, port=None, id=None, collection_name=None):
         """
         :param log_level: Optional. (String) Debug level
         ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
         :param host: Optional. Hostname for the database server.
         :param port: Optional. Port for the database server.
         :param id: Optional. ID of the app. Document metadata will have this id.
+        :param collection_name: Optional. Collection name for the database.
         """
         super().__init__(
-            log_level=log_level, embedding_fn=AppConfig.default_embedding_function(), host=host, port=port, id=id
+            log_level=log_level,
+            embedding_fn=AppConfig.default_embedding_function(),
+            host=host,
+            port=port,
+            id=id,
+            collection_name=collection_name,
         )
 
     @staticmethod

+ 4 - 2
embedchain/config/apps/BaseAppConfig.py

@@ -8,19 +8,21 @@ class BaseAppConfig(BaseConfig):
     Parent config to initialize an instance of `App`, `OpenSourceApp` or `CustomApp`.
     """
 
-    def __init__(self, log_level=None, embedding_fn=None, db=None, host=None, port=None, id=None):
+    def __init__(self, log_level=None, embedding_fn=None, db=None, host=None, port=None, id=None, collection_name=None):
         """
         :param log_level: Optional. (String) Debug level
         ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
         :param embedding_fn: Embedding function to use.
         :param db: Optional. (Vector) database instance to use for embeddings.
-        :param id: Optional. ID of the app. Document metadata will have this id.
         :param host: Optional. Hostname for the database server.
         :param port: Optional. Port for the database server.
+        :param id: Optional. ID of the app. Document metadata will have this id.
+        :param collection_name: Optional. Collection name for the database.
         """
         self._setup_logging(log_level)
 
         self.db = db if db else BaseAppConfig.default_db(embedding_fn=embedding_fn, host=host, port=port)
+        self.collection_name = collection_name if collection_name else "embedchain_store"
         self.id = id
         return
 

+ 4 - 2
embedchain/config/apps/CustomAppConfig.py

@@ -24,8 +24,8 @@ class CustomAppConfig(BaseAppConfig):
         host=None,
         port=None,
         id=None,
+        collection_name=None,
         provider: Providers = None,
-        model=None,
         open_source_app_config=None,
         deployment_name=None,
     ):
@@ -35,9 +35,10 @@ class CustomAppConfig(BaseAppConfig):
         :param embedding_fn: Optional. Embedding function to use.
         :param embedding_fn_model: Optional. Model name to use for embedding function.
         :param db: Optional. (Vector) database to use for embeddings.
-        :param id: Optional. ID of the app. Document metadata will have this id.
         :param host: Optional. Hostname for the database server.
         :param port: Optional. Port for the database server.
+        :param id: Optional. ID of the app. Document metadata will have this id.
+        :param collection_name: Optional. Collection name for the database.
         :param provider: Optional. (Providers): LLM Provider to use.
         :param open_source_app_config: Optional. Config instance needed for open source apps.
         """
@@ -58,6 +59,7 @@ class CustomAppConfig(BaseAppConfig):
             host=host,
             port=port,
             id=id,
+            collection_name=collection_name,
         )
 
     @staticmethod

+ 3 - 1
embedchain/config/apps/OpenSourceAppConfig.py

@@ -8,11 +8,12 @@ class OpenSourceAppConfig(BaseAppConfig):
     Config to initialize an embedchain custom `OpenSourceApp` instance, with extra config options.
     """
 
-    def __init__(self, log_level=None, host=None, port=None, id=None, model=None):
+    def __init__(self, log_level=None, host=None, port=None, id=None, collection_name=None, model=None):
         """
         :param log_level: Optional. (String) Debug level
         ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
         :param id: Optional. ID of the app. Document metadata will have this id.
+        :param collection_name: Optional. Collection name for the database.
         :param host: Optional. Hostname for the database server.
         :param port: Optional. Port for the database server.
         :param model: Optional. GPT4ALL uses the model to instantiate the class.
@@ -26,6 +27,7 @@ class OpenSourceAppConfig(BaseAppConfig):
             host=host,
             port=port,
             id=id,
+            collection_name=collection_name,
         )
 
     @staticmethod

+ 9 - 1
embedchain/embedchain.py

@@ -32,7 +32,7 @@ class EmbedChain:
 
         self.config = config
         self.db_client = self.config.db.client
-        self.collection = self.config.db.collection
+        self.collection = self.config.db._get_or_create_collection(self.config.collection_name)
         self.user_asks = []
         self.is_docs_site_instance = False
         self.online = False
@@ -325,6 +325,14 @@ class EmbedChain:
         memory.chat_memory.add_ai_message(streamed_answer)
         logging.info(f"Answer: {streamed_answer}")
 
+    def set_collection(self, collection_name):
+        """
+        Set the collection to use.
+
+        :param collection_name: The name of the collection to use.
+        """
+        self.collection = self.config.db._get_or_create_collection(collection_name)
+
     def count(self):
         """
         Count the number of embeddings.

+ 0 - 1
embedchain/vectordb/base_vector_db.py

@@ -3,7 +3,6 @@ class BaseVectorDB:
 
     def __init__(self):
         self.client = self._get_or_create_db()
-        self.collection = self._get_or_create_collection()
 
     def _get_or_create_db(self):
         """Get or create the database."""

+ 2 - 2
embedchain/vectordb/chroma_db.py

@@ -39,9 +39,9 @@ class ChromaDB(BaseVectorDB):
         """Get or create the database."""
         return self.client
 
-    def _get_or_create_collection(self):
+    def _get_or_create_collection(self, name):
         """Get or create the collection."""
         return self.client.get_or_create_collection(
-            "embedchain_store",
+            name=name,
             embedding_function=self.embedding_fn,
         )

+ 183 - 0
tests/vectordb/test_chroma_db.py

@@ -72,3 +72,186 @@ class TestChromaDbHostsLoglevel(unittest.TestCase):
 
         self.assertEqual(mock_client.call_args[0][0].chroma_server_host, None)
         self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, None)
+
+class TestChromaDbDuplicateHandling:
+    def test_duplicates_throw_warning(self, caplog):
+        """
+        Test that adding duplicates logs a warning.
+        """
+        # Start with a clean app
+        App().reset()
+
+        app = App()
+        app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
+        app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
+        assert "Insert of existing embedding ID: 0" in caplog.text
+        assert "Add of existing embedding ID: 0" in caplog.text
+
+    def test_duplicates_collections_no_warning(self, caplog):
+        """
+        Test that different collections can have duplicates.
+        """
+        # NOTE: Not part of the TestChromaDbCollection because `unittest.TestCase` doesn't have caplog.
+
+        # Start with a clean app
+        App().reset()
+
+        app = App()
+        app.set_collection("test_collection_1")
+        app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
+        app.set_collection("test_collection_2")
+        app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
+        assert "Insert of existing embedding ID: 0" not in caplog.text # not
+        assert "Add of existing embedding ID: 0" not in caplog.text # not
+
+
+class TestChromaDbCollection(unittest.TestCase):
+    def test_init_with_default_collection(self):
+        """
+        Test if the `App` instance is initialized with the correct default collection name.
+        """
+        app = App()
+
+        self.assertEqual(app.collection.name, "embedchain_store")
+
+    def test_init_with_custom_collection(self):
+        """
+        Test if the `App` instance is initialized with the correct custom collection name.
+        """
+        config = AppConfig(collection_name="test_collection")
+        app = App(config)
+
+        self.assertEqual(app.collection.name, "test_collection")
+
+    def test_set_collection(self):
+        """
+        Test if the `App` collection is correctly switched using the `set_collection` method.
+        """
+        app = App()
+        app.set_collection("test_collection")
+
+        self.assertEqual(app.collection.name, "test_collection")
+
+    def test_changes_encapsulated(self):
+        """
+        Test that changes to one collection do not affect the other collection
+        """
+        # Start with a clean app
+        App().reset()
+
+        app = App()
+        app.set_collection("test_collection_1")
+        # Collection should be empty when created
+        self.assertEqual(app.count(), 0)
+
+        app.collection.add(embeddings=[0, 0, 0], ids=["0"])
+        # After adding, should contain one item
+        self.assertEqual(app.count(), 1)
+
+        app.set_collection("test_collection_2")
+        # New collection is empty
+        self.assertEqual(app.count(), 0)
+
+        # Adding to new collection should not affect existing collection
+        app.collection.add(embeddings=[0, 0, 0], ids=["0"])
+        app.set_collection("test_collection_1")
+        # Should still be 1, not 2.
+        self.assertEqual(app.count(), 1)
+
+    def test_collections_are_persistent(self):
+        """
+        Test that a collection can be picked up later.
+        """
+        # Start with a clean app
+        App().reset()
+
+        app = App()
+        app.set_collection("test_collection_1")
+        app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
+        del app
+
+        app = App()
+        app.set_collection("test_collection_1")
+        self.assertEqual(app.count(), 1)
+
+    def test_parallel_collections(self):
+        """
+        Test that two apps can have different collections open in parallel.
+        Switching the names will allow instant access to the collection of
+        the other app.
+        """
+        # Start clean
+        App().reset()
+
+        # Create two apps
+        app1 = App(AppConfig(collection_name="test_collection_1"))
+        app2 = App(AppConfig(collection_name="test_collection_2"))
+
+        # app2 has been created last, but adding to app1 will still write to collection 1.
+        app1.collection.add(embeddings=[0, 0, 0], ids=["0"])
+        self.assertEqual(app1.count(), 1)
+        self.assertEqual(app2.count(), 0)
+
+        # Add data
+        app1.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["1", "2"])
+        app2.collection.add(embeddings=[0, 0, 0], ids=["0"])
+
+        # Swap names and test
+        app1.set_collection('test_collection_2')
+        self.assertEqual(app1.count(), 1)
+        app2.set_collection('test_collection_1')
+        self.assertEqual(app2.count(), 3)
+
+    def test_ids_share_collections(self):
+        """
+        Different ids should still share collections.
+        """
+        # Start clean
+        App().reset()
+
+        # Create two apps
+        app1 = App(AppConfig(collection_name="one_collection", id="new_app_id_1"))
+        app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2"))
+
+        # Add data
+        app1.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["0", "1"])
+        app2.collection.add(embeddings=[0, 0, 0], ids=["2"])
+
+        # Both should have the same collection
+        self.assertEqual(app1.count(), 3)
+        self.assertEqual(app2.count(), 3)
+
+    def test_reset(self):
+        """
+        Resetting should hit all collections and ids.
+        """
+        # Start clean
+        App().reset()
+
+        # Create four apps.
+        # app1, which we are about to reset, shares a collection with one, and an id with the other, none with the last.
+        app1 = App(AppConfig(collection_name="one_collection", id="new_app_id_1"))
+        app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2"))
+        app3 = App(AppConfig(collection_name="three_collection", id="new_app_id_1"))
+        app4 = App(AppConfig(collection_name="four_collection", id="new_app_id_4"))
+
+        # Each one of them gets data
+        app1.collection.add(embeddings=[0, 0, 0], ids=["1"])
+        app2.collection.add(embeddings=[0, 0, 0], ids=["2"])
+        app3.collection.add(embeddings=[0, 0, 0], ids=["3"])
+        app4.collection.add(embeddings=[0, 0, 0], ids=["4"])
+
+        # Resetting the first one should reset them all.
+        app1.reset()
+
+        # Reinstantiate them
+        app1 = App(AppConfig(collection_name="one_collection", id="new_app_id_1"))
+        app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2"))
+        app3 = App(AppConfig(collection_name="three_collection", id="new_app_id_3"))
+        app4 = App(AppConfig(collection_name="four_collection", id="new_app_id_3"))
+
+        # All should be empty
+        self.assertEqual(app1.count(), 0)
+        self.assertEqual(app2.count(), 0)
+        self.assertEqual(app3.count(), 0)
+        self.assertEqual(app4.count(), 0)