test_chat.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. import os
  2. import unittest
  3. from unittest.mock import patch
  4. from embedchain import App
  5. from embedchain.config import AppConfig
  6. class TestApp(unittest.TestCase):
  7. os.environ["OPENAI_API_KEY"] = "test_key"
  8. def setUp(self):
  9. self.app = App(config=AppConfig(collect_metrics=False))
  10. @patch("embedchain.embedchain.memory", autospec=True)
  11. @patch.object(App, "retrieve_from_database", return_value=["Test context"])
  12. @patch.object(App, "get_answer_from_llm", return_value="Test answer")
  13. def test_chat_with_memory(self, mock_answer, mock_retrieve, mock_memory):
  14. """
  15. This test checks the functionality of the 'chat' method in the App class with respect to the chat history
  16. memory.
  17. The 'chat' method is called twice. The first call initializes the chat history memory.
  18. The second call is expected to use the chat history from the first call.
  19. Key assumptions tested:
  20. - After the first call, 'memory.chat_memory.add_user_message' and 'memory.chat_memory.add_ai_message' are
  21. called with correct arguments, adding the correct chat history.
  22. - During the second call, the 'chat' method uses the chat history from the first call.
  23. The test isolates the 'chat' method behavior by mocking out 'retrieve_from_database', 'get_answer_from_llm' and
  24. 'memory' methods.
  25. """
  26. mock_memory.load_memory_variables.return_value = {"history": []}
  27. app = App()
  28. # First call to chat
  29. first_answer = app.chat("Test query 1")
  30. self.assertEqual(first_answer, "Test answer")
  31. mock_memory.chat_memory.add_user_message.assert_called_once_with("Test query 1")
  32. mock_memory.chat_memory.add_ai_message.assert_called_once_with("Test answer")
  33. mock_memory.chat_memory.add_user_message.reset_mock()
  34. mock_memory.chat_memory.add_ai_message.reset_mock()
  35. # Second call to chat
  36. second_answer = app.chat("Test query 2")
  37. self.assertEqual(second_answer, "Test answer")
  38. mock_memory.chat_memory.add_user_message.assert_called_once_with("Test query 2")
  39. mock_memory.chat_memory.add_ai_message.assert_called_once_with("Test answer")