12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152 |
- import os
- import unittest
- from string import Template
- from unittest.mock import patch
- from embedchain import App
- from embedchain.embedchain import QueryConfig
class TestApp(unittest.TestCase):
    """Tests for the prompt that ``App`` builds for a query."""

    # NOTE(review): this assignment runs once, when the class body executes at
    # import time, and nothing here undoes it, so the fake key stays in
    # ``os.environ`` for every other test in the process. Kept unchanged to
    # preserve that behavior; consider ``patch.dict(os.environ, ...)`` in
    # ``setUp`` instead.
    os.environ["OPENAI_API_KEY"] = "test_key"

    # Marker that precedes the prompt text in the log messages inspected here.
    _PROMPT_PREFIX = "Prompt: "

    def setUp(self):
        """Create a fresh ``App`` instance for each test."""
        self.app = App()

    @patch("logging.info")
    def test_query_logs_same_prompt_as_dry_run(self, mock_logging_info):
        """
        Test that the 'query' method logs the same prompt as the 'dry_run' method.

        ``query`` does not return the prompt it builds, so the prompt is read
        back from the messages passed to ``logging.info``. The mock is reset
        between the two calls so each prompt is attributed to the call that
        logged it, rather than relying on the position of every
        ``logging.info`` call in one combined history.
        """
        with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
            mock_retrieve.return_value = ["Test context"]
            input_query = "Test query"
            config = QueryConfig(
                number_documents=3,
                template=Template("Question: $query, context: $context, history: $history"),
                history=["Past context 1", "Past context 2"],
            )

            # Replace ``get_answer_from_llm`` with a mock so neither call
            # reaches the real implementation.
            with patch.object(self.app, "get_answer_from_llm"):
                self.app.dry_run(input_query, config)
                dry_run_prompts = self._logged_prompts(mock_logging_info)

                # Forget dry_run's log calls so only query's are inspected next.
                mock_logging_info.reset_mock()

                self.app.query(input_query, config)
                query_prompts = self._logged_prompts(mock_logging_info)

        # Guard against a vacuous pass where neither call logged a prompt.
        self.assertTrue(dry_run_prompts, "dry_run did not log a prompt")
        self.assertEqual(dry_run_prompts, query_prompts)

    def _logged_prompts(self, mock_logging_info):
        """
        Return the prompts found in the calls recorded by *mock_logging_info*.

        Only calls whose first positional argument is a string containing the
        prompt prefix are considered; every other ``logging.info`` call is
        ignored.
        """
        return [
            self.extract_prompt(args[0])
            for args, _kwargs in mock_logging_info.call_args_list
            if args and isinstance(args[0], str) and self._PROMPT_PREFIX in args[0]
        ]

    def extract_prompt(self, log_message):
        """
        Extract the prompt from a log message of the form ``"...Prompt: <prompt>"``.

        Args:
            log_message: A message that was passed to ``logging.info``.

        Returns:
            Everything after the first occurrence of ``"Prompt: "``.

        Raises:
            ValueError: If *log_message* does not contain ``"Prompt: "``.
        """
        _, separator, prompt = log_message.partition(self._PROMPT_PREFIX)
        if not separator:
            raise ValueError(
                f"log message has no {self._PROMPT_PREFIX!r} marker: {log_message!r}"
            )
        return prompt
|