# test_chat.py — unit tests for the App.chat() flow (history, templates, where-filters).
  1. import os
  2. import unittest
  3. from unittest.mock import MagicMock, patch
  4. from embedchain import App
  5. from embedchain.config import AppConfig, BaseLlmConfig
  6. from embedchain.llm.base import BaseLlm
  7. from embedchain.memory.base import ChatHistory
  8. from embedchain.memory.message import ChatMessage
  9. class TestApp(unittest.TestCase):
  10. def setUp(self):
  11. os.environ["OPENAI_API_KEY"] = "test_key"
  12. self.app = App(config=AppConfig(collect_metrics=False))
  13. @patch.object(App, "_retrieve_from_database", return_value=["Test context"])
  14. @patch.object(BaseLlm, "get_answer_from_llm", return_value="Test answer")
  15. def test_chat_with_memory(self, mock_get_answer, mock_retrieve):
  16. """
  17. This test checks the functionality of the 'chat' method in the App class with respect to the chat history
  18. memory.
  19. The 'chat' method is called twice. The first call initializes the chat history memory.
  20. The second call is expected to use the chat history from the first call.
  21. Key assumptions tested:
  22. called with correct arguments, adding the correct chat history.
  23. - After the first call, 'memory.chat_memory.add_user_message' and 'memory.chat_memory.add_ai_message' are
  24. - During the second call, the 'chat' method uses the chat history from the first call.
  25. The test isolates the 'chat' method behavior by mocking out '_retrieve_from_database', 'get_answer_from_llm' and
  26. 'memory' methods.
  27. """
  28. config = AppConfig(collect_metrics=False)
  29. app = App(config=config)
  30. with patch.object(BaseLlm, "add_history") as mock_history:
  31. first_answer = app.chat("Test query 1")
  32. self.assertEqual(first_answer, "Test answer")
  33. mock_history.assert_called_with(app.config.id, "Test query 1", "Test answer", session_id="default")
  34. second_answer = app.chat("Test query 2", session_id="test_session")
  35. self.assertEqual(second_answer, "Test answer")
  36. mock_history.assert_called_with(app.config.id, "Test query 2", "Test answer", session_id="test_session")
  37. @patch.object(App, "_retrieve_from_database", return_value=["Test context"])
  38. @patch.object(BaseLlm, "get_answer_from_llm", return_value="Test answer")
  39. def test_template_replacement(self, mock_get_answer, mock_retrieve):
  40. """
  41. Tests that if a default template is used and it doesn't contain history,
  42. the default template is swapped in.
  43. Also tests that a dry run does not change the history
  44. """
  45. with patch.object(ChatHistory, "get") as mock_memory:
  46. mock_message = ChatMessage()
  47. mock_message.add_user_message("Test query 1")
  48. mock_message.add_ai_message("Test answer")
  49. mock_memory.return_value = [mock_message]
  50. config = AppConfig(collect_metrics=False)
  51. app = App(config=config)
  52. first_answer = app.chat("Test query 1")
  53. self.assertEqual(first_answer, "Test answer")
  54. self.assertEqual(len(app.llm.history), 1)
  55. history = app.llm.history
  56. dry_run = app.chat("Test query 2", dry_run=True)
  57. self.assertIn("History:", dry_run)
  58. self.assertEqual(history, app.llm.history)
  59. self.assertEqual(len(app.llm.history), 1)
  60. @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
  61. def test_chat_with_where_in_params(self):
  62. """
  63. Test where filter
  64. """
  65. with patch.object(self.app, "_retrieve_from_database") as mock_retrieve:
  66. mock_retrieve.return_value = ["Test context"]
  67. with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
  68. mock_answer.return_value = "Test answer"
  69. answer = self.app.chat("Test query", where={"attribute": "value"})
  70. self.assertEqual(answer, "Test answer")
  71. _args, kwargs = mock_retrieve.call_args
  72. self.assertEqual(kwargs.get("input_query"), "Test query")
  73. self.assertEqual(kwargs.get("where"), {"attribute": "value"})
  74. mock_answer.assert_called_once()
  75. @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
  76. def test_chat_with_where_in_chat_config(self):
  77. """
  78. This test checks the functionality of the 'chat' method in the App class.
  79. It simulates a scenario where the '_retrieve_from_database' method returns a context list based on
  80. a where filter and 'get_llm_model_answer' returns an expected answer string.
  81. The 'chat' method is expected to call '_retrieve_from_database' with the where filter specified
  82. in the BaseLlmConfig and 'get_llm_model_answer' methods appropriately and return the right answer.
  83. Key assumptions tested:
  84. - '_retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of
  85. BaseLlmConfig.
  86. - 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test.
  87. - 'chat' method returns the value it received from 'get_llm_model_answer'.
  88. The test isolates the 'chat' method behavior by mocking out '_retrieve_from_database' and
  89. 'get_llm_model_answer' methods.
  90. """
  91. with patch.object(self.app.llm, "get_llm_model_answer") as mock_answer:
  92. mock_answer.return_value = "Test answer"
  93. with patch.object(self.app.db, "query") as mock_database_query:
  94. mock_database_query.return_value = ["Test context"]
  95. llm_config = BaseLlmConfig(where={"attribute": "value"})
  96. answer = self.app.chat("Test query", llm_config)
  97. self.assertEqual(answer, "Test answer")
  98. _args, kwargs = mock_database_query.call_args
  99. self.assertEqual(kwargs.get("input_query"), "Test query")
  100. where = kwargs.get("where")
  101. assert "app_id" in where
  102. assert "attribute" in where
  103. mock_answer.assert_called_once()