12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970 |
- import unittest
- from string import Template
- from embedchain import App
- from embedchain.config import AppConfig, BaseLlmConfig
class TestGeneratePrompt(unittest.TestCase):
    """Tests for ``generate_prompt`` on the LLM attached to an ``App``."""

    def setUp(self):
        # Each test gets its own App, created with metrics collection disabled.
        app_config = AppConfig(collect_metrics=False)
        self.app = App(config=app_config)

    def test_generate_prompt_with_template(self):
        """
        Verify prompt generation when the config carries a custom template.

        A ``BaseLlmConfig`` built from a user-supplied ``Template`` is installed
        on the LLM, and ``generate_prompt`` is called with one query and three
        contexts. The prompt must be the template text with ``${context}``
        replaced by the contexts separated by " | " and ``${query}`` replaced
        by the query.
        """
        # Arrange
        query = "Test query"
        context_list = ["Context 1", "Context 2", "Context 3"]
        template_text = "You are a bot. Context: ${context} - Query: ${query} - Helpful answer:"
        self.app.llm.config = BaseLlmConfig(template=Template(template_text))

        # Act
        generated = self.app.llm.generate_prompt(query, context_list)

        # Assert
        expected = "You are a bot. Context: Context 1 | Context 2 | Context 3 - Query: Test query - Helpful answer:"
        self.assertEqual(generated, expected)

    def test_generate_prompt_with_contexts_list(self):
        """
        Verify prompt generation for a list of contexts with a default config.

        A default-constructed ``BaseLlmConfig`` is installed on the LLM and
        ``generate_prompt`` is called with one query and three contexts. The
        expected value is obtained by substituting the " | "-separated
        contexts and the query into that same config's ``prompt`` template.
        """
        # Arrange
        query = "Test query"
        context_list = ["Context 1", "Context 2", "Context 3"]
        default_config = BaseLlmConfig()
        self.app.llm.config = default_config

        # Act
        generated = self.app.llm.generate_prompt(query, context_list)

        # Assert
        expected = default_config.prompt.substitute(
            context="Context 1 | Context 2 | Context 3",
            query=query,
        )
        self.assertEqual(generated, expected)

    def test_generate_prompt_with_history(self):
        """
        Verify prompt generation when the prompt template has a ``$history`` slot.

        The config's ``prompt`` is replaced with a template that references
        ``$context``, ``$query`` and ``$history``; two history entries are then
        registered through ``set_history``. The generated prompt must show
        those entries in the history slot, separated by a newline.
        """
        # Arrange
        query = "Test query"
        history_config = BaseLlmConfig()
        history_config.prompt = Template("Context: $context | Query: $query | History: $history")
        self.app.llm.config = history_config
        self.app.llm.set_history(["Past context 1", "Past context 2"])

        # Act
        generated = self.app.llm.generate_prompt(query, ["Test context"])

        # Assert
        expected = "Context: Test context | Query: Test query | History: Past context 1\nPast context 2"
        self.assertEqual(generated, expected)
|