test_generate_prompt.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. import unittest
  2. from string import Template
  3. from embedchain import App
  4. from embedchain.config import AppConfig, QueryConfig
  class TestGeneratePrompt(unittest.TestCase):
      """Unit tests for ``App.generate_prompt``.

      Each test builds a ``QueryConfig``, asks the app to render a prompt from a
      query plus a list of context strings, and compares the rendered text with
      the expected prompt.
      """

      def setUp(self):
          # A fresh App per test, constructed with metrics collection disabled.
          app_config = AppConfig(collect_metrics=False)
          self.app = App(config=app_config)

      def test_generate_prompt_with_template(self):
          """A custom ``Template`` on the config controls the prompt layout.

          The contexts are expected to be joined with `` | `` and placed at
          ``${context}``; the query is placed at ``${query}``.
          """
          query = "Test query"
          context_chunks = ["Context 1", "Context 2", "Context 3"]
          custom_template = Template(
              "You are a bot. Context: ${context} - Query: ${query} - Helpful answer:"
          )
          query_config = QueryConfig(template=custom_template)

          rendered = self.app.generate_prompt(query, context_chunks, query_config)

          self.assertEqual(
              rendered,
              "You are a bot. Context: Context 1 | Context 2 | Context 3 - Query: Test query - Helpful answer:",
          )

      def test_generate_prompt_with_contexts_list(self):
          """Every context in the list reaches the prompt under a no-argument config.

          The expected prompt is built by substituting the pipe-joined contexts and
          the query into the template that ``QueryConfig()`` carries.
          """
          query = "Test query"
          context_chunks = ["Context 1", "Context 2", "Context 3"]
          query_config = QueryConfig()

          rendered = self.app.generate_prompt(query, context_chunks, query_config)

          joined_contexts = "Context 1 | Context 2 | Context 3"
          self.assertEqual(
              rendered,
              query_config.template.substitute(context=joined_contexts, query=query),
          )

      def test_generate_prompt_with_history(self):
          """A ``history`` list on the config is rendered at ``$history``.

          The list is expected to appear in its ``str()`` form.
          """
          query_config = QueryConfig(history=["Past context 1", "Past context 2"])
          # The template is swapped in after construction so the test fixes the layout itself.
          query_config.template = Template("Context: $context | Query: $query | History: $history")

          rendered = self.app.generate_prompt("Test query", ["Test context"], query_config)

          self.assertEqual(
              rendered,
              "Context: Test context | Query: Test query | History: ['Past context 1', 'Past context 2']",
          )