test_generate_prompt.py

import unittest
from string import Template

from embedchain import App
from embedchain.config import AppConfig, BaseLlmConfig


class TestGeneratePrompt(unittest.TestCase):
    def setUp(self):
        self.app = App(config=AppConfig(collect_metrics=False))

    def test_generate_prompt_with_template(self):
        """
        Tests that the generate_prompt method correctly formats the prompt using
        a custom template provided in the BaseLlmConfig instance.

        This test sets up a scenario with an input query, a list of contexts,
        and a custom template, and then calls generate_prompt. It checks that
        the returned prompt incorporates all the contexts and the query in the
        format specified by the template.
        """
        # Setup
        input_query = "Test query"
        contexts = ["Context 1", "Context 2", "Context 3"]
        template = "You are a bot. Context: ${context} - Query: ${query} - Helpful answer:"
        config = BaseLlmConfig(template=Template(template))
        self.app.llm.config = config

        # Execute
        result = self.app.llm.generate_prompt(input_query, contexts)

        # Assert
        expected_result = (
            "You are a bot. Context: Context 1 | Context 2 | Context 3 - Query: Test query - Helpful answer:"
        )
        self.assertEqual(result, expected_result)
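
    # Note: the expected strings in these tests assume that generate_prompt
    # joins the retrieved contexts with " | " before substituting the result
    # for the ${context} placeholder.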

    def test_generate_prompt_with_contexts_list(self):
        """
        Tests that the generate_prompt method correctly handles a list of contexts.

        This test sets up a scenario with an input query and a list of contexts,
        and then calls generate_prompt. It checks that the returned prompt
        correctly includes all the contexts and the query.
        """
        # Setup
        input_query = "Test query"
        contexts = ["Context 1", "Context 2", "Context 3"]
        config = BaseLlmConfig()
        self.app.llm.config = config

        # Execute
        result = self.app.llm.generate_prompt(input_query, contexts)

        # Assert
        expected_result = config.prompt.substitute(context="Context 1 | Context 2 | Context 3", query=input_query)
        self.assertEqual(result, expected_result)

    def test_generate_prompt_with_history(self):
        """
        Tests the generate_prompt method with a BaseLlmConfig whose prompt
        template includes a $history placeholder.
        """
        config = BaseLlmConfig()
        config.prompt = Template("Context: $context | Query: $query | History: $history")
        self.app.llm.config = config
        self.app.llm.set_history(["Past context 1", "Past context 2"])
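        # set_history appears to join past messages with newlines; the
        # expected prompt below assumes that behavior.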
        prompt = self.app.llm.generate_prompt("Test query", ["Test context"])

        expected_prompt = "Context: Test context | Query: Test query | History: Past context 1\nPast context 2"
        self.assertEqual(prompt, expected_prompt)
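

# Standard unittest entry point so the file can be run directly
# (`python test_generate_prompt.py`) in addition to `python -m unittest`.
if __name__ == "__main__":
    unittest.main()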