test_query.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. import os
  2. import unittest
  3. from unittest.mock import MagicMock, patch
  4. from embedchain import App
  5. from embedchain.config import AppConfig, BaseLlmConfig
  6. class TestApp(unittest.TestCase):
  7. os.environ["OPENAI_API_KEY"] = "test_key"
  8. def setUp(self):
  9. self.app = App(config=AppConfig(collect_metrics=False))
  10. @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
  11. def test_query(self):
  12. """
  13. This test checks the functionality of the 'query' method in the App class.
  14. It simulates a scenario where the 'retrieve_from_database' method returns a context list and
  15. 'get_llm_model_answer' returns an expected answer string.
  16. The 'query' method is expected to call 'retrieve_from_database' and 'get_llm_model_answer' methods
  17. appropriately and return the right answer.
  18. Key assumptions tested:
  19. - 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
  20. LlmConfig.
  21. - 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
  22. - 'query' method returns the value it received from 'get_llm_model_answer'.
  23. The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and
  24. 'get_llm_model_answer' methods.
  25. """
  26. with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
  27. mock_retrieve.return_value = ["Test context"]
  28. with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
  29. mock_answer.return_value = "Test answer"
  30. _answer = self.app.query(input_query="Test query")
  31. # Ensure retrieve_from_database was called
  32. mock_retrieve.assert_called_once()
  33. # Check the call arguments
  34. args, kwargs = mock_retrieve.call_args
  35. input_query_arg = kwargs.get("input_query")
  36. self.assertEqual(input_query_arg, "Test query")
  37. mock_answer.assert_called_once()
  38. @patch("openai.ChatCompletion.create")
  39. def test_query_config_app_passing(self, mock_create):
  40. mock_create.return_value = {"choices": [{"message": {"content": "response"}}]} # Mock response
  41. config = AppConfig(collect_metrics=False)
  42. chat_config = BaseLlmConfig(system_prompt="Test system prompt")
  43. app = App(config=config, llm_config=chat_config)
  44. app.llm.get_llm_model_answer("Test query")
  45. # Test system_prompt: Check that the 'create' method was called with the correct 'messages' argument
  46. messages_arg = mock_create.call_args.kwargs["messages"]
  47. self.assertTrue(messages_arg[0].get("role"), "system")
  48. self.assertEqual(messages_arg[0].get("content"), "Test system prompt")
  49. self.assertTrue(messages_arg[1].get("role"), "user")
  50. self.assertEqual(messages_arg[1].get("content"), "Test query")
  51. # TODO: Add tests for other config variables
  52. @patch("openai.ChatCompletion.create")
  53. def test_app_passing(self, mock_create):
  54. mock_create.return_value = {"choices": [{"message": {"content": "response"}}]} # Mock response
  55. config = AppConfig(collect_metrics=False)
  56. chat_config = BaseLlmConfig()
  57. app = App(config=config, llm_config=chat_config, system_prompt="Test system prompt")
  58. self.assertEqual(app.llm.config.system_prompt, "Test system prompt")
  59. app.llm.get_llm_model_answer("Test query")
  60. # Test system_prompt: Check that the 'create' method was called with the correct 'messages' argument
  61. messages_arg = mock_create.call_args.kwargs["messages"]
  62. self.assertTrue(messages_arg[0].get("role"), "system")
  63. self.assertEqual(messages_arg[0].get("content"), "Test system prompt")
  64. @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
  65. def test_query_with_where_in_params(self):
  66. """
  67. This test checks the functionality of the 'query' method in the App class.
  68. It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
  69. a where filter and 'get_llm_model_answer' returns an expected answer string.
  70. The 'query' method is expected to call 'retrieve_from_database' with the where filter and
  71. 'get_llm_model_answer' methods appropriately and return the right answer.
  72. Key assumptions tested:
  73. - 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
  74. LlmConfig.
  75. - 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
  76. - 'query' method returns the value it received from 'get_llm_model_answer'.
  77. The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and
  78. 'get_llm_model_answer' methods.
  79. """
  80. with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
  81. mock_retrieve.return_value = ["Test context"]
  82. with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
  83. mock_answer.return_value = "Test answer"
  84. answer = self.app.query("Test query", where={"attribute": "value"})
  85. self.assertEqual(answer, "Test answer")
  86. _args, kwargs = mock_retrieve.call_args
  87. self.assertEqual(kwargs.get("input_query"), "Test query")
  88. self.assertEqual(kwargs.get("where"), {"attribute": "value"})
  89. mock_answer.assert_called_once()
  90. @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
  91. def test_query_with_where_in_query_config(self):
  92. """
  93. This test checks the functionality of the 'query' method in the App class.
  94. It simulates a scenario where the 'retrieve_from_database' method returns a context list based on
  95. a where filter and 'get_llm_model_answer' returns an expected answer string.
  96. The 'query' method is expected to call 'retrieve_from_database' with the where filter and
  97. 'get_llm_model_answer' methods appropriately and return the right answer.
  98. Key assumptions tested:
  99. - 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
  100. LlmConfig.
  101. - 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
  102. - 'query' method returns the value it received from 'get_llm_model_answer'.
  103. The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and
  104. 'get_llm_model_answer' methods.
  105. """
  106. with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
  107. mock_answer.return_value = "Test answer"
  108. with patch.object(self.app.db, "query") as mock_database_query:
  109. mock_database_query.return_value = ["Test context"]
  110. llm_config = BaseLlmConfig(where={"attribute": "value"})
  111. answer = self.app.query("Test query", llm_config)
  112. self.assertEqual(answer, "Test answer")
  113. _args, kwargs = mock_database_query.call_args
  114. self.assertEqual(kwargs.get("input_query"), "Test query")
  115. self.assertEqual(kwargs.get("where"), {"attribute": "value"})
  116. mock_answer.assert_called_once()