
feat: where filter in vector database (#518)

sw8fbar 1 year ago
commit 3e66ddf69a
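
The commit threads an optional `where` filter from the public `query`/`chat` API down to the vector database lookup. A minimal usage sketch, assuming an `App` with data already added; the `author` metadata key and the questions are illustrative, not part of the commit:

```python
from embedchain import App
from embedchain.config import QueryConfig

app = App()
# ... app.add(...) calls that attach {"author": "alice"} metadata to the stored documents ...

# Pass the filter directly as a keyword argument; it is forwarded to the vector database.
answer = app.chat("What did Alice write?", where={"author": "alice"})

# ... or carry it on the config object; retrieve_from_database falls back to config.where
# when no explicit `where` argument is given.
answer = app.query("What did Alice write?", QueryConfig(where={"author": "alice"}))
```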

+ 5 - 6
embedchain/apps/PersonApp.py

@@ -4,8 +4,7 @@ from embedchain.apps.App import App
 from embedchain.apps.OpenSourceApp import OpenSourceApp
 from embedchain.config import ChatConfig, QueryConfig
 from embedchain.config.apps.BaseAppConfig import BaseAppConfig
-from embedchain.config.QueryConfig import (DEFAULT_PROMPT,
-                                           DEFAULT_PROMPT_WITH_HISTORY)
+from embedchain.config.QueryConfig import DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY
 from embedchain.helper_classes.json_serializable import register_deserializable
 
 
@@ -60,12 +59,12 @@ class PersonApp(EmbedChainPersonApp, App):
     """
 
     def query(self, input_query, config: QueryConfig = None, dry_run=False):
-        config = self.add_person_template_to_config(DEFAULT_PROMPT, config)
-        return super().query(input_query, config, dry_run)
+        config = self.add_person_template_to_config(DEFAULT_PROMPT, config, where=None)
+        return super().query(input_query, config, dry_run, where=None)
 
-    def chat(self, input_query, config: ChatConfig = None, dry_run=False):
+    def chat(self, input_query, config: ChatConfig = None, dry_run=False, where=None):
         config = self.add_person_template_to_config(DEFAULT_PROMPT_WITH_HISTORY, config)
-        return super().chat(input_query, config, dry_run)
+        return super().chat(input_query, config, dry_run, where)
 
 
 @register_deserializable

+ 3 - 0
embedchain/config/ChatConfig.py

@@ -38,6 +38,7 @@ class ChatConfig(QueryConfig):
         stream: bool = False,
         deployment_name=None,
         system_prompt: Optional[str] = None,
+        where=None,
     ):
         """
         Initializes the ChatConfig instance.
@@ -57,6 +58,7 @@ class ChatConfig(QueryConfig):
         :param stream: Optional. Control if response is streamed back to the user
         :param deployment_name: t.b.a.
         :param system_prompt: Optional. System prompt string.
+        :param where: Optional. A dictionary of key-value pairs to filter the database results.
         :raises ValueError: If the template is not valid as template should contain
         $context and $query and $history
         """
@@ -77,6 +79,7 @@ class ChatConfig(QueryConfig):
             stream=stream,
             deployment_name=deployment_name,
             system_prompt=system_prompt,
+            where=where,
         )
 
     def set_history(self, history):

+ 3 - 0
embedchain/config/QueryConfig.py

@@ -67,6 +67,7 @@ class QueryConfig(BaseConfig):
         stream: bool = False,
         deployment_name=None,
         system_prompt: Optional[str] = None,
+        where=None,
     ):
         """
         Initializes the QueryConfig instance.
@@ -87,6 +88,7 @@ class QueryConfig(BaseConfig):
         :param stream: Optional. Control if response is streamed back to user
         :param deployment_name: t.b.a.
         :param system_prompt: Optional. System prompt string.
+        :param where: Optional. A dictionary of key-value pairs to filter the database results.
         :raises ValueError: If the template is not valid as template should
         contain $context and $query (and optionally $history).
         """
@@ -127,6 +129,7 @@ class QueryConfig(BaseConfig):
         if not isinstance(stream, bool):
             raise ValueError("`stream` should be bool")
         self.stream = stream
+        self.where = where
 
     def validate_template(self, template: Template):
         """

+ 19 - 6
embedchain/embedchain.py

@@ -250,16 +250,27 @@ class EmbedChain(JSONSerializable):
         """
         raise NotImplementedError
 
-    def retrieve_from_database(self, input_query, config: QueryConfig):
+    def retrieve_from_database(self, input_query, config: QueryConfig, where=None):
         """
         Queries the vector database based on the given input query.
         Gets relevant doc based on the query
 
         :param input_query: The query to use.
         :param config: The query configuration.
+        :param where: Optional. A dictionary of key-value pairs to filter the database results.
         :return: The content of the document that matched your query.
         """
-        where = {"app_id": self.config.id} if self.config.id is not None else {}  # optional filter
+
+        if where is not None:
+            where = where
+        elif config is not None and config.where is not None:
+            where = config.where
+        else:
+            where = {}
+
+        if self.config.id is not None:
+            where.update({"app_id": self.config.id})
+
         contents = self.db.query(
             input_query=input_query,
             n_results=config.number_documents,
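
The block above sets the filter precedence: an explicit `where` argument wins over `config.where`, an empty dict is the fallback, and the app id is merged into whichever filter was chosen. An illustrative sketch, assuming `app.config.id` is `"my_app"` (names are hypothetical, not part of the commit):

```python
# Explicit argument overrides the config-level filter; the app id is merged in afterwards.
app.query(
    "Test query",
    QueryConfig(where={"source": "docs"}),
    where={"author": "alice"},
)
# retrieve_from_database queries the vector database with:
#     {"author": "alice", "app_id": "my_app"}
```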
@@ -311,7 +322,7 @@ class EmbedChain(JSONSerializable):
         logging.info(f"Access search to get answers for {input_query}")
         return search.run(input_query)
 
-    def query(self, input_query, config: QueryConfig = None, dry_run=False):
+    def query(self, input_query, config: QueryConfig = None, dry_run=False, where=None):
         """
         Queries the vector database based on the given input query.
         Gets relevant doc based on the query and then passes it to an
@@ -326,6 +337,7 @@ class EmbedChain(JSONSerializable):
         by the vector database's doc retrieval.
         The only thing the dry run does not consider is the cut-off due to
         the `max_tokens` parameter.
+        :param where: Optional. A dictionary of key-value pairs to filter the database results.
         :return: The answer to the query.
         """
         if config is None:
@@ -336,7 +348,7 @@ class EmbedChain(JSONSerializable):
         k = {}
         if self.online:
             k["web_search_result"] = self.access_search_and_get_results(input_query)
-        contexts = self.retrieve_from_database(input_query, config)
+        contexts = self.retrieve_from_database(input_query, config, where)
         prompt = self.generate_prompt(input_query, contexts, config, **k)
         logging.info(f"Prompt: {prompt}")
 
@@ -362,7 +374,7 @@ class EmbedChain(JSONSerializable):
             yield chunk
         logging.info(f"Answer: {streamed_answer}")
 
-    def chat(self, input_query, config: ChatConfig = None, dry_run=False):
+    def chat(self, input_query, config: ChatConfig = None, dry_run=False, where=None):
         """
         Queries the vector database on the given input query.
         Gets relevant doc based on the query and then passes it to an
@@ -378,6 +390,7 @@ class EmbedChain(JSONSerializable):
         by the vector database's doc retrieval.
         The only thing the dry run does not consider is the cut-off due to
         the `max_tokens` parameter.
+        :param where: Optional. A dictionary of key-value pairs to filter the database results.
         :return: The answer to the query.
         """
         if config is None:
@@ -388,7 +401,7 @@ class EmbedChain(JSONSerializable):
         k = {}
         if self.online:
             k["web_search_result"] = self.access_search_and_get_results(input_query)
-        contexts = self.retrieve_from_database(input_query, config)
+        contexts = self.retrieve_from_database(input_query, config, where)
 
         chat_history = self.memory.load_memory_variables({})["history"]
 

+ 64 - 2
tests/embedchain/test_chat.py

@@ -1,9 +1,9 @@
 import os
 import unittest
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
 
 from embedchain import App
-from embedchain.config import AppConfig
+from embedchain.config import AppConfig, ChatConfig
 
 
 class TestApp(unittest.TestCase):
@@ -35,3 +35,65 @@ class TestApp(unittest.TestCase):
         second_answer = app.chat("Test query 2")
         self.assertEqual(second_answer, "Test answer")
         self.assertEqual(len(app.memory.chat_memory.messages), 4)
+
+    @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
+    def test_chat_with_where_in_params(self):
+        """
+        This test checks the functionality of the 'chat' method in the App class.
+        It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
+        a where filter and 'get_llm_model_answer' returns an expected answer string.
+
+        The 'chat' method is expected to call 'retrieve_from_database' with the where filter and
+        'get_llm_model_answer' methods appropriately and return the right answer.
+
+        Key assumptions tested:
+        - 'retrieve_from_database' method is called exactly once with arguments: "Test chat" and an instance of
+            QueryConfig.
+        - 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
+        - 'chat' method returns the value it received from 'get_llm_model_answer'.
+
+        The test isolates the 'chat' method behavior by mocking out 'retrieve_from_database' and
+        'get_llm_model_answer' methods.
+        """
+        with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
+            mock_retrieve.return_value = ["Test context"]
+            with patch.object(self.app, "get_llm_model_answer") as mock_answer:
+                mock_answer.return_value = "Test answer"
+                answer = self.app.chat("Test chat", where={"attribute": "value"})
+
+        self.assertEqual(answer, "Test answer")
+        self.assertEqual(mock_retrieve.call_args[0][0], "Test chat")
+        self.assertEqual(mock_retrieve.call_args[0][2], {"attribute": "value"})
+        mock_answer.assert_called_once()
+
+    @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
+    def test_chat_with_where_in_chat_config(self):
+        """
+        This test checks the functionality of the 'chat' method in the App class.
+        It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
+        a where filter and 'get_llm_model_answer' returns an expected answer string.
+
+        The 'chat' method is expected to call 'retrieve_from_database' with the where filter specified
+        in the QueryConfig and 'get_llm_model_answer' methods appropriately and return the right answer.
+
+        Key assumptions tested:
+        - 'retrieve_from_database' method is called exactly once with arguments: "Test chat" and an instance of
+            QueryConfig.
+        - 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
+        - 'chat' method returns the value it received from 'get_llm_model_answer'.
+
+        The test isolates the 'chat' method behavior by mocking out 'retrieve_from_database' and
+        'get_llm_model_answer' methods.
+        """
+        with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
+            mock_retrieve.return_value = ["Test context"]
+            with patch.object(self.app, "get_llm_model_answer") as mock_answer:
+                mock_answer.return_value = "Test answer"
+                chatConfig = ChatConfig(where={"attribute": "value"})
+                answer = self.app.chat("Test chat", chatConfig)
+
+        self.assertEqual(answer, "Test answer")
+        self.assertEqual(mock_retrieve.call_args[0][0], "Test chat")
+        self.assertEqual(mock_retrieve.call_args[0][1].where, {"attribute": "value"})
+        self.assertIsInstance(mock_retrieve.call_args[0][1], ChatConfig)
+        mock_answer.assert_called_once()

+ 62 - 0
tests/embedchain/test_query.py

@@ -73,3 +73,65 @@ class TestApp(unittest.TestCase):
         messages_arg = mock_create.call_args.kwargs["messages"]
         self.assertEqual(messages_arg[0]["role"], "system")
         self.assertEqual(messages_arg[0]["content"], "Test system prompt")
+
+    @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
+    def test_query_with_where_in_params(self):
+        """
+        This test checks the functionality of the 'query' method in the App class.
+        It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
+        a where filter and 'get_llm_model_answer' returns an expected answer string.
+
+        The 'query' method is expected to call 'retrieve_from_database' with the where filter and
+        'get_llm_model_answer' methods appropriately and return the right answer.
+
+        Key assumptions tested:
+        - 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
+            QueryConfig.
+        - 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
+        - 'query' method returns the value it received from 'get_llm_model_answer'.
+
+        The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and
+        'get_llm_model_answer' methods.
+        """
+        with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
+            mock_retrieve.return_value = ["Test context"]
+            with patch.object(self.app, "get_llm_model_answer") as mock_answer:
+                mock_answer.return_value = "Test answer"
+                answer = self.app.query("Test query", where={"attribute": "value"})
+
+        self.assertEqual(answer, "Test answer")
+        self.assertEqual(mock_retrieve.call_args[0][0], "Test query")
+        self.assertEqual(mock_retrieve.call_args[0][2], {"attribute": "value"})
+        mock_answer.assert_called_once()
+
+    @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
+    def test_query_with_where_in_query_config(self):
+        """
+        This test checks the functionality of the 'query' method in the App class.
+        It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
+        a where filter and 'get_llm_model_answer' returns an expected answer string.
+
+        The 'query' method is expected to call 'retrieve_from_database' with the where filter and
+        'get_llm_model_answer' methods appropriately and return the right answer.
+
+        Key assumptions tested:
+        - 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
+            QueryConfig.
+        - 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
+        - 'query' method returns the value it received from 'get_llm_model_answer'.
+
+        The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and
+        'get_llm_model_answer' methods.
+        """
+        with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
+            mock_retrieve.return_value = ["Test context"]
+            with patch.object(self.app, "get_llm_model_answer") as mock_answer:
+                mock_answer.return_value = "Test answer"
+                queryConfig = QueryConfig(where={"attribute": "value"})
+                answer = self.app.query("Test query", queryConfig)
+
+        self.assertEqual(answer, "Test answer")
+        self.assertEqual(mock_retrieve.call_args[0][0], "Test query")
+        self.assertEqual(mock_retrieve.call_args[0][1].where, {"attribute": "value"})
+        self.assertIsInstance(mock_retrieve.call_args[0][1], QueryConfig)
+        mock_answer.assert_called_once()