test_query.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. import os
  2. import unittest
  3. from unittest.mock import MagicMock, patch
  4. from embedchain import App
  5. from embedchain.config import AppConfig, QueryConfig
  class TestApp(unittest.TestCase):
      """Tests for ``App.query`` and the system-prompt handling of ``App.get_llm_model_answer``.

      Each test mocks the collaborator it is not exercising (database retrieval, the
      LLM answer, or ``openai.ChatCompletion.create``) and inspects how it was called.
      """

      # Runs once, when this class body executes at import time; the fake key then
      # stays in ``os.environ`` for the rest of the test process. Presumably needed so
      # ``App`` can be constructed without a real OpenAI key -- confirm.
      os.environ["OPENAI_API_KEY"] = "test_key"

      def setUp(self):
          """Create a fresh ``App`` (metrics collection disabled) for every test."""
          self.app = App(config=AppConfig(collect_metrics=False))

      def _query_with_mocks(self, *query_args, **query_kwargs):
          """Call ``self.app.query`` with retrieval and answer generation mocked out.

          ``retrieve_from_database`` returns ``["Test context"]`` and
          ``get_llm_model_answer`` returns ``"Test answer"``.

          Returns:
              ``(answer, mock_retrieve, mock_answer)``. The mocks keep their recorded
              calls after the patches are removed, so callers can assert on them.
          """
          retrieve_patch = patch.object(self.app, "retrieve_from_database", return_value=["Test context"])
          answer_patch = patch.object(self.app, "get_llm_model_answer", return_value="Test answer")
          with retrieve_patch as mock_retrieve, answer_patch as mock_answer:
              answer = self.app.query(*query_args, **query_kwargs)
          return answer, mock_retrieve, mock_answer

      # NOTE(review): this decorator passes the MagicMock *class* (not an instance) as the
      # replacement for ``Collection.add``. Because ``query``'s collaborators are mocked,
      # ``add`` is presumably never reached in these query tests -- confirm whether the
      # patch is still needed.
      @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
      def test_query(self):
          """
          Check that ``query`` fetches context once and returns the LLM answer.

          Key assumptions tested:
          - ``retrieve_from_database`` is called exactly once, with "Test query" as its
            first positional argument and a ``QueryConfig`` instance as its second.
          - ``get_llm_model_answer`` is called exactly once (its arguments are not checked).
          - ``query`` returns the value produced by ``get_llm_model_answer``.
          """
          answer, mock_retrieve, mock_answer = self._query_with_mocks("Test query")

          self.assertEqual(answer, "Test answer")
          mock_retrieve.assert_called_once()
          self.assertEqual(mock_retrieve.call_args[0][0], "Test query")
          self.assertIsInstance(mock_retrieve.call_args[0][1], QueryConfig)
          mock_answer.assert_called_once()

      @patch("openai.ChatCompletion.create")
      def test_query_config_app_passing(self, mock_create):
          """Check that a system prompt set on ``QueryConfig`` is sent as the first chat message."""
          mock_create.return_value = {"choices": [{"message": {"content": "response"}}]}  # Mock response
          # Metrics disabled to match ``setUp``.
          config = AppConfig(collect_metrics=False)
          chat_config = QueryConfig(system_prompt="Test system prompt")
          app = App(config=config)

          app.get_llm_model_answer("Test query", chat_config)

          # Test system_prompt: the 'create' call must lead with a "system" message.
          messages_arg = mock_create.call_args.kwargs["messages"]
          self.assertEqual(messages_arg[0]["role"], "system")
          self.assertEqual(messages_arg[0]["content"], "Test system prompt")
          # TODO: Add tests for other config variables

      @patch("openai.ChatCompletion.create")
      def test_app_passing(self, mock_create):
          """Check that a system prompt given to ``App`` is used when ``QueryConfig`` sets none."""
          mock_create.return_value = {"choices": [{"message": {"content": "response"}}]}  # Mock response
          # Metrics disabled to match ``setUp``.
          config = AppConfig(collect_metrics=False)
          chat_config = QueryConfig()
          app = App(config=config, system_prompt="Test system prompt")

          app.get_llm_model_answer("Test query", chat_config)

          # Test system_prompt: the 'create' call must lead with a "system" message.
          messages_arg = mock_create.call_args.kwargs["messages"]
          self.assertEqual(messages_arg[0]["role"], "system")
          self.assertEqual(messages_arg[0]["content"], "Test system prompt")

      @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
      def test_query_with_where_in_params(self):
          """
          Check that a ``where`` filter passed to ``query`` reaches the database lookup.

          Key assumptions tested:
          - ``retrieve_from_database`` is called exactly once, with "Test query", a
            ``QueryConfig`` instance and the ``where`` dict as its first three
            positional arguments.
          - ``get_llm_model_answer`` is called exactly once (its arguments are not checked).
          - ``query`` returns the value produced by ``get_llm_model_answer``.
          """
          answer, mock_retrieve, mock_answer = self._query_with_mocks("Test query", where={"attribute": "value"})

          self.assertEqual(answer, "Test answer")
          mock_retrieve.assert_called_once()
          self.assertEqual(mock_retrieve.call_args[0][0], "Test query")
          self.assertIsInstance(mock_retrieve.call_args[0][1], QueryConfig)
          self.assertEqual(mock_retrieve.call_args[0][2], {"attribute": "value"})
          mock_answer.assert_called_once()

      @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
      def test_query_with_where_in_query_config(self):
          """
          Check that a ``where`` filter set on ``QueryConfig`` reaches the database lookup.

          Key assumptions tested:
          - ``retrieve_from_database`` is called exactly once, with "Test query" as its
            first positional argument and a ``QueryConfig`` whose ``where`` equals the
            filter as its second.
          - ``get_llm_model_answer`` is called exactly once (its arguments are not checked).
          - ``query`` returns the value produced by ``get_llm_model_answer``.
          """
          query_config = QueryConfig(where={"attribute": "value"})
          answer, mock_retrieve, mock_answer = self._query_with_mocks("Test query", query_config)

          self.assertEqual(answer, "Test answer")
          mock_retrieve.assert_called_once()
          self.assertEqual(mock_retrieve.call_args[0][0], "Test query")
          self.assertEqual(mock_retrieve.call_args[0][1].where, {"attribute": "value"})
          self.assertIsInstance(mock_retrieve.call_args[0][1], QueryConfig)
          mock_answer.assert_called_once()