test_dryrun.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. import os
  2. import unittest
  3. from string import Template
  4. from unittest.mock import patch
  5. from embedchain import App
  6. from embedchain.embedchain import QueryConfig
  7. class TestApp(unittest.TestCase):
  8. os.environ["OPENAI_API_KEY"] = "test_key"
  9. def setUp(self):
  10. self.app = App()
  11. @patch("logging.info")
  12. def test_query_logs_same_prompt_as_dry_run(self, mock_logging_info):
  13. """
  14. Test that the 'query' method logs the same prompt as the 'dry_run' method.
  15. This is the only way I found to test the prompt in query, that's not returned.
  16. """
  17. with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
  18. mock_retrieve.return_value = ["Test context"]
  19. input_query = "Test query"
  20. config = QueryConfig(
  21. number_documents=3,
  22. template=Template("Question: $query, context: $context, history: $history"),
  23. history=["Past context 1", "Past context 2"],
  24. )
  25. with patch.object(self.app, "get_answer_from_llm"):
  26. self.app.dry_run(input_query, config)
  27. self.app.query(input_query, config)
  28. # Access the log messages captured during the execution
  29. logged_messages = [call[0][0] for call in mock_logging_info.call_args_list]
  30. # Extract the prompts from the log messages
  31. dry_run_prompt = self.extract_prompt(logged_messages[0])
  32. query_prompt = self.extract_prompt(logged_messages[1])
  33. # Perform assertions on the prompts
  34. self.assertEqual(dry_run_prompt, query_prompt)
  35. def extract_prompt(self, log_message):
  36. """
  37. Extracts the prompt value from the log message.
  38. Adjust this method based on the log message format in your implementation.
  39. """
  40. # Modify this logic based on your log message format
  41. prefix = "Prompt: "
  42. return log_message.split(prefix, 1)[1]