Procházet zdrojové kódy

feat: where filter in vector database (#518)

sw8fbar před 1 rokem
rodič
revize
3e66ddf69a

+ 5 - 6
embedchain/apps/PersonApp.py

@@ -4,8 +4,7 @@ from embedchain.apps.App import App
 from embedchain.apps.OpenSourceApp import OpenSourceApp
 from embedchain.config import ChatConfig, QueryConfig
 from embedchain.config.apps.BaseAppConfig import BaseAppConfig
-from embedchain.config.QueryConfig import (DEFAULT_PROMPT,
-                                           DEFAULT_PROMPT_WITH_HISTORY)
+from embedchain.config.QueryConfig import DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY
 from embedchain.helper_classes.json_serializable import register_deserializable
 
 
@@ -60,12 +59,12 @@ class PersonApp(EmbedChainPersonApp, App):
     """
 
     def query(self, input_query, config: QueryConfig = None, dry_run=False):
-        config = self.add_person_template_to_config(DEFAULT_PROMPT, config)
-        return super().query(input_query, config, dry_run)
+        config = self.add_person_template_to_config(DEFAULT_PROMPT, config, where=None)
+        return super().query(input_query, config, dry_run, where=None)
 
-    def chat(self, input_query, config: ChatConfig = None, dry_run=False):
+    def chat(self, input_query, config: ChatConfig = None, dry_run=False, where=None):
         config = self.add_person_template_to_config(DEFAULT_PROMPT_WITH_HISTORY, config)
-        return super().chat(input_query, config, dry_run)
+        return super().chat(input_query, config, dry_run, where)
 
 
 @register_deserializable

+ 3 - 0
embedchain/config/ChatConfig.py

@@ -38,6 +38,7 @@ class ChatConfig(QueryConfig):
         stream: bool = False,
         deployment_name=None,
         system_prompt: Optional[str] = None,
+        where=None,
     ):
         """
         Initializes the ChatConfig instance.
@@ -57,6 +58,7 @@ class ChatConfig(QueryConfig):
         :param stream: Optional. Control if response is streamed back to the user
         :param deployment_name: t.b.a.
         :param system_prompt: Optional. System prompt string.
+        :param where: Optional. A dictionary of key-value pairs to filter the database results.
         :raises ValueError: If the template is not valid as template should contain
         $context and $query and $history
         """
@@ -77,6 +79,7 @@ class ChatConfig(QueryConfig):
             stream=stream,
             deployment_name=deployment_name,
             system_prompt=system_prompt,
+            where=where,
         )
 
     def set_history(self, history):

+ 3 - 0
embedchain/config/QueryConfig.py

@@ -67,6 +67,7 @@ class QueryConfig(BaseConfig):
         stream: bool = False,
         deployment_name=None,
         system_prompt: Optional[str] = None,
+        where=None,
     ):
         """
         Initializes the QueryConfig instance.
@@ -87,6 +88,7 @@ class QueryConfig(BaseConfig):
         :param stream: Optional. Control if response is streamed back to user
         :param deployment_name: t.b.a.
         :param system_prompt: Optional. System prompt string.
+        :param where: Optional. A dictionary of key-value pairs to filter the database results.
         :raises ValueError: If the template is not valid as template should
         contain $context and $query (and optionally $history).
         """
@@ -127,6 +129,7 @@ class QueryConfig(BaseConfig):
         if not isinstance(stream, bool):
             raise ValueError("`stream` should be bool")
         self.stream = stream
+        self.where = where
 
     def validate_template(self, template: Template):
         """

+ 19 - 6
embedchain/embedchain.py

@@ -250,16 +250,27 @@ class EmbedChain(JSONSerializable):
         """
         raise NotImplementedError
 
-    def retrieve_from_database(self, input_query, config: QueryConfig):
+    def retrieve_from_database(self, input_query, config: QueryConfig, where=None):
         """
         Queries the vector database based on the given input query.
         Gets relevant doc based on the query
 
         :param input_query: The query to use.
         :param config: The query configuration.
+        :param where: Optional. A dictionary of key-value pairs to filter the database results.
         :return: The content of the document that matched your query.
         """
-        where = {"app_id": self.config.id} if self.config.id is not None else {}  # optional filter
+
+        if where is not None:
+            where = where
+        elif config is not None and config.where is not None:
+            where = config.where
+        else:
+            where = {}
+
+        if self.config.id is not None:
+            where.update({"app_id": self.config.id})
+
         contents = self.db.query(
             input_query=input_query,
             n_results=config.number_documents,
@@ -311,7 +322,7 @@ class EmbedChain(JSONSerializable):
         logging.info(f"Access search to get answers for {input_query}")
         return search.run(input_query)
 
-    def query(self, input_query, config: QueryConfig = None, dry_run=False):
+    def query(self, input_query, config: QueryConfig = None, dry_run=False, where=None):
         """
         Queries the vector database based on the given input query.
         Gets relevant doc based on the query and then passes it to an
@@ -326,6 +337,7 @@ class EmbedChain(JSONSerializable):
         by the vector database's doc retrieval.
         The only thing the dry run does not consider is the cut-off due to
         the `max_tokens` parameter.
+        :param where: Optional. A dictionary of key-value pairs to filter the database results.
         :return: The answer to the query.
         """
         if config is None:
@@ -336,7 +348,7 @@ class EmbedChain(JSONSerializable):
         k = {}
         if self.online:
             k["web_search_result"] = self.access_search_and_get_results(input_query)
-        contexts = self.retrieve_from_database(input_query, config)
+        contexts = self.retrieve_from_database(input_query, config, where)
         prompt = self.generate_prompt(input_query, contexts, config, **k)
         logging.info(f"Prompt: {prompt}")
 
@@ -362,7 +374,7 @@ class EmbedChain(JSONSerializable):
             yield chunk
         logging.info(f"Answer: {streamed_answer}")
 
-    def chat(self, input_query, config: ChatConfig = None, dry_run=False):
+    def chat(self, input_query, config: ChatConfig = None, dry_run=False, where=None):
         """
         Queries the vector database on the given input query.
         Gets relevant doc based on the query and then passes it to an
@@ -378,6 +390,7 @@ class EmbedChain(JSONSerializable):
         by the vector database's doc retrieval.
         The only thing the dry run does not consider is the cut-off due to
         the `max_tokens` parameter.
+        :param where: Optional. A dictionary of key-value pairs to filter the database results.
         :return: The answer to the query.
         """
         if config is None:
@@ -388,7 +401,7 @@ class EmbedChain(JSONSerializable):
         k = {}
         if self.online:
             k["web_search_result"] = self.access_search_and_get_results(input_query)
-        contexts = self.retrieve_from_database(input_query, config)
+        contexts = self.retrieve_from_database(input_query, config, where)
 
         chat_history = self.memory.load_memory_variables({})["history"]
 

+ 64 - 2
tests/embedchain/test_chat.py

@@ -1,9 +1,9 @@
 import os
 import unittest
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
 
 from embedchain import App
-from embedchain.config import AppConfig
+from embedchain.config import AppConfig, ChatConfig
 
 
 class TestApp(unittest.TestCase):
@@ -35,3 +35,65 @@ class TestApp(unittest.TestCase):
         second_answer = app.chat("Test query 2")
         self.assertEqual(second_answer, "Test answer")
         self.assertEqual(len(app.memory.chat_memory.messages), 4)
+
+    @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
+    def test_chat_with_where_in_params(self):
+        """
+        This test checks the functionality of the 'chat' method in the App class.
+        It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
+        a where filter and 'get_llm_model_answer' returns an expected answer string.
+
+        The 'chat' method is expected to call 'retrieve_from_database' with the where filter  and
+        'get_llm_model_answer' methods appropriately and return the right answer.
+
+        Key assumptions tested:
+        - 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
+            QueryConfig.
+        - 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
+        - 'chat' method returns the value it received from 'get_llm_model_answer'.
+
+        The test isolates the 'chat' method behavior by mocking out 'retrieve_from_database' and
+        'get_llm_model_answer' methods.
+        """
+        with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
+            mock_retrieve.return_value = ["Test context"]
+            with patch.object(self.app, "get_llm_model_answer") as mock_answer:
+                mock_answer.return_value = "Test answer"
+                answer = self.app.chat("Test chat", where={"attribute": "value"})
+
+        self.assertEqual(answer, "Test answer")
+        self.assertEqual(mock_retrieve.call_args[0][0], "Test chat")
+        self.assertEqual(mock_retrieve.call_args[0][2], {"attribute": "value"})
+        mock_answer.assert_called_once()
+
+    @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
+    def test_chat_with_where_in_chat_config(self):
+        """
+        This test checks the functionality of the 'chat' method in the App class.
+        It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
+        a where filter and 'get_llm_model_answer' returns an expected answer string.
+
+        The 'chat' method is expected to call 'retrieve_from_database' with the where filter specified
+        in the QueryConfig and 'get_llm_model_answer' methods appropriately and return the right answer.
+
+        Key assumptions tested:
+        - 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
+            QueryConfig.
+        - 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
+        - 'chat' method returns the value it received from 'get_llm_model_answer'.
+
+        The test isolates the 'chat' method behavior by mocking out 'retrieve_from_database' and
+        'get_llm_model_answer' methods.
+        """
+        with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
+            mock_retrieve.return_value = ["Test context"]
+            with patch.object(self.app, "get_llm_model_answer") as mock_answer:
+                mock_answer.return_value = "Test answer"
+                chatConfig = ChatConfig(where={"attribute": "value"})
+                answer = self.app.chat("Test chat", chatConfig)
+
+        self.assertEqual(answer, "Test answer")
+        self.assertEqual(mock_retrieve.call_args[0][0], "Test chat")
+        self.assertEqual(mock_retrieve.call_args[0][1].where, {"attribute": "value"})
+        self.assertIsInstance(mock_retrieve.call_args[0][1], ChatConfig)
+        mock_answer.assert_called_once()

+ 62 - 0
tests/embedchain/test_query.py

@@ -73,3 +73,65 @@ class TestApp(unittest.TestCase):
         messages_arg = mock_create.call_args.kwargs["messages"]
         self.assertEqual(messages_arg[0]["role"], "system")
         self.assertEqual(messages_arg[0]["content"], "Test system prompt")
+
+    @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
+    def test_query_with_where_in_params(self):
+        """
+        This test checks the functionality of the 'query' method in the App class.
+        It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
+        a where filter and 'get_llm_model_answer' returns an expected answer string.
+
+        The 'query' method is expected to call 'retrieve_from_database' with the where filter  and
+        'get_llm_model_answer' methods appropriately and return the right answer.
+
+        Key assumptions tested:
+        - 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
+            QueryConfig.
+        - 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
+        - 'query' method returns the value it received from 'get_llm_model_answer'.
+
+        The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and
+        'get_llm_model_answer' methods.
+        """
+        with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
+            mock_retrieve.return_value = ["Test context"]
+            with patch.object(self.app, "get_llm_model_answer") as mock_answer:
+                mock_answer.return_value = "Test answer"
+                answer = self.app.query("Test query", where={"attribute": "value"})
+
+        self.assertEqual(answer, "Test answer")
+        self.assertEqual(mock_retrieve.call_args[0][0], "Test query")
+        self.assertEqual(mock_retrieve.call_args[0][2], {"attribute": "value"})
+        mock_answer.assert_called_once()
+
+    @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
+    def test_query_with_where_in_query_config(self):
+        """
+        This test checks the functionality of the 'query' method in the App class.
+        It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
+        a where filter and 'get_llm_model_answer' returns an expected answer string.
+
+        The 'query' method is expected to call 'retrieve_from_database' with the where filter  and
+        'get_llm_model_answer' methods appropriately and return the right answer.
+
+        Key assumptions tested:
+        - 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
+            QueryConfig.
+        - 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
+        - 'query' method returns the value it received from 'get_llm_model_answer'.
+
+        The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and
+        'get_llm_model_answer' methods.
+        """
+        with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
+            mock_retrieve.return_value = ["Test context"]
+            with patch.object(self.app, "get_llm_model_answer") as mock_answer:
+                mock_answer.return_value = "Test answer"
+                queryConfig = QueryConfig(where={"attribute": "value"})
+                answer = self.app.query("Test query", queryConfig)
+
+        self.assertEqual(answer, "Test answer")
+        self.assertEqual(mock_retrieve.call_args[0][0], "Test query")
+        self.assertEqual(mock_retrieve.call_args[0][1].where, {"attribute": "value"})
+        self.assertIsInstance(mock_retrieve.call_args[0][1], QueryConfig)
+        mock_answer.assert_called_once()