Migrate from `template` to `prompt` arg while keeping backward compatibility (#1066)

Sidharth Mohanty, 1 year ago
Parent commit: d9d529987e
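For orientation, here is a minimal usage sketch of the backward-compatible behaviour this commit introduces. It assumes `BaseLlmConfig` is importable from `embedchain.config` (the import path is not part of these hunks); the keyword handling mirrors the changes to `embedchain/config/llm/base.py` below.

```python
from embedchain.config import BaseLlmConfig  # assumed import path

# Preferred going forward: pass the prompt text via the new `prompt` keyword.
config = BaseLlmConfig(
    prompt="Use the context to answer.\n$context\n\nQuery: $query\n\nHelpful Answer:"
)

# The old `template` keyword still works, but it logs a deprecation warning
# and its value is copied into `config.prompt` internally.
legacy = BaseLlmConfig(
    template="Use the context to answer.\n$context\n\nQuery: $query\n\nHelpful Answer:"
)

assert config.prompt.template == legacy.prompt.template
```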

configs/full-stack.yaml (+1 -1)

@@ -15,7 +15,7 @@ llm:
     max_tokens: 1000
     top_p: 1
     stream: false
-    template: |
+    prompt: |
       Use the following pieces of context to answer the query at the end.
       If you don't know the answer, just say that you don't know, don't try to make up an answer.
 

docs/api-reference/advanced/configuration.mdx (+4 -4)

@@ -26,7 +26,7 @@ llm:
     top_p: 1
     stream: false
     api_key: sk-xxx
-    template: |
+    prompt: |
       Use the following pieces of context to answer the query at the end.
       If you don't know the answer, just say that you don't know, don't try to make up an answer.
 
@@ -73,7 +73,7 @@ chunker:
       "max_tokens": 1000,
       "top_p": 1,
       "stream": false,
-      "template": "Use the following pieces of context to answer the query at the end.\nIf you don't know the answer, just say that you don't know, don't try to make up an answer.\n$context\n\nQuery: $query\n\nHelpful Answer:",
+      "prompt": "Use the following pieces of context to answer the query at the end.\nIf you don't know the answer, just say that you don't know, don't try to make up an answer.\n$context\n\nQuery: $query\n\nHelpful Answer:",
       "system_prompt": "Act as William Shakespeare. Answer the following questions in the style of William Shakespeare.",
       "api_key": "sk-xxx"
     }
@@ -117,7 +117,7 @@ config = {
             'max_tokens': 1000,
             'top_p': 1,
             'stream': False,
-            'template': (
+            'prompt': (
                 "Use the following pieces of context to answer the query at the end.\n"
                 "If you don't know the answer, just say that you don't know, don't try to make up an answer.\n"
                 "$context\n\nQuery: $query\n\nHelpful Answer:"
@@ -170,7 +170,7 @@ Alright, let's dive into what each key means in the yaml config above:
         - `max_tokens` (Integer): Controls how many tokens are used in the response.
         - `top_p` (Float): Controls the diversity of word selection. A higher value (closer to 1) makes word selection more diverse.
         - `stream` (Boolean): Controls if the response is streamed back to the user (set to false).
-        - `template` (String): A custom template for the prompt that the model uses to generate responses.
+        - `prompt` (String): A prompt for the model to follow when generating responses; it must contain the `$context` and `$query` variables.
         - `system_prompt` (String): A system prompt for the model to follow when generating responses, in this case, it's set to the style of William Shakespeare.
         -  `stream` (Boolean): Controls if the response is streamed back to the user (set to false).
         - `number_documents` (Integer): Number of documents to pull from the vectordb as context, defaults to 1
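As the key description above notes, the prompt must contain both placeholders. A hedged sketch of what happens when one is missing, per the validation added in `embedchain/config/llm/base.py` further down (again assuming the `embedchain.config` import path):

```python
from embedchain.config import BaseLlmConfig  # assumed import path

# Contains both required placeholders: accepted.
BaseLlmConfig(prompt="Context: $context\nQuery: $query\nAnswer:")

# Missing $context: the constructor raises ValueError because validate_prompt
# requires both the `query` and `context` keys.
try:
    BaseLlmConfig(prompt="Query: $query\nAnswer:")
except ValueError as exc:
    print(exc)
```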

docs/examples/rest-api/create.mdx (+1 -1)

@@ -37,7 +37,7 @@ llm:
     max_tokens: 1000
     top_p: 1
     stream: false
-    template: |
+    prompt: |
       Use the following pieces of context to answer the query at the end.
       If you don't know the answer, just say that you don't know, don't try to make up an answer.
 

embedchain/config/llm/base.py (+31 -18)

@@ -1,3 +1,4 @@
+import logging
 import re
 from string import Template
 from typing import Any, Dict, List, Optional
@@ -59,6 +60,7 @@ class BaseLlmConfig(BaseConfig):
         self,
         number_documents: int = 3,
         template: Optional[Template] = None,
+        prompt: Optional[Template] = None,
         model: Optional[str] = None,
         temperature: float = 0,
         max_tokens: int = 1000,
@@ -80,8 +82,11 @@ class BaseLlmConfig(BaseConfig):
         context, defaults to 1
         :type number_documents: int, optional
         :param template:  The `Template` instance to use as a template for
-        prompt, defaults to None
+        prompt, defaults to None (deprecated)
         :type template: Optional[Template], optional
+        :param prompt: The `Template` instance to use as a template for
+        prompt, defaults to None
+        :type prompt: Optional[Template], optional
         :param model: Controls the OpenAI model used, defaults to None
         :type model: Optional[str], optional
         :param temperature:  Controls the randomness of the model's output.
@@ -106,8 +111,16 @@ class BaseLlmConfig(BaseConfig):
         contain $context and $query (and optionally $history)
         :raises ValueError: Stream is not boolean
         """
-        if template is None:
-            template = DEFAULT_PROMPT_TEMPLATE
+        if template is not None:
+            logging.warning(
+                "The `template` argument is deprecated and will be removed in a future version. "
+                + "Please use `prompt` instead."
+            )
+            if prompt is None:
+                prompt = template
+
+        if prompt is None:
+            prompt = DEFAULT_PROMPT_TEMPLATE
 
         self.number_documents = number_documents
         self.temperature = temperature
@@ -120,37 +133,37 @@ class BaseLlmConfig(BaseConfig):
         self.callbacks = callbacks
         self.api_key = api_key
 
-        if type(template) is str:
-            template = Template(template)
+        if type(prompt) is str:
+            prompt = Template(prompt)
 
-        if self.validate_template(template):
-            self.template = template
+        if self.validate_prompt(prompt):
+            self.prompt = prompt
         else:
-            raise ValueError("`template` should have `query` and `context` keys and potentially `history` (if used).")
+            raise ValueError("The 'prompt' should have 'query' and 'context' keys and potentially 'history' (if used).")
 
         if not isinstance(stream, bool):
             raise ValueError("`stream` should be bool")
         self.stream = stream
         self.where = where
 
-    def validate_template(self, template: Template) -> bool:
+    def validate_prompt(self, prompt: Template) -> bool:
         """
-        validate the template
+        validate the prompt
 
-        :param template: the template to validate
-        :type template: Template
+        :param prompt: the prompt to validate
+        :type prompt: Template
         :return: valid (true) or invalid (false)
         :rtype: bool
         """
-        return re.search(query_re, template.template) and re.search(context_re, template.template)
+        return re.search(query_re, prompt.template) and re.search(context_re, prompt.template)
 
-    def _validate_template_history(self, template: Template) -> bool:
+    def _validate_prompt_history(self, prompt: Template) -> bool:
         """
-        validate the template with history
+        validate the prompt with history
 
-        :param template: the template to validate
-        :type template: Template
+        :param prompt: the prompt to validate
+        :type prompt: Template
         :return: valid (true) or invalid (false)
         :rtype: bool
         """
-        return re.search(history_re, template.template)
+        return re.search(history_re, prompt.template)
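For readers skimming the hunk above, here is the same resolution order restated as a self-contained sketch that runs without embedchain installed (the default prompt text is abbreviated here):

```python
import logging
from string import Template

# Abbreviated stand-in for embedchain's DEFAULT_PROMPT_TEMPLATE.
DEFAULT_PROMPT_TEMPLATE = Template("$context\n\nQuery: $query\n\nHelpful Answer:")


def resolve_prompt(template=None, prompt=None):
    # Deprecated path: warn, and only fall back to `template` when no
    # `prompt` was given, as in BaseLlmConfig.__init__ above.
    if template is not None:
        logging.warning("`template` is deprecated; use `prompt` instead.")
        if prompt is None:
            prompt = template
    if prompt is None:
        prompt = DEFAULT_PROMPT_TEMPLATE
    # Plain strings are promoted to string.Template instances.
    if isinstance(prompt, str):
        prompt = Template(prompt)
    return prompt


print(resolve_prompt(template="Old: $context $query").template)  # deprecated path
print(resolve_prompt(prompt="New: $context $query").template)    # preferred path
```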

embedchain/llm/base.py (+14 -14)

@@ -74,19 +74,19 @@ class BaseLlm(JSONSerializable):
         if web_search_result:
             context_string = self._append_search_and_context(context_string, web_search_result)
 
-        template_contains_history = self.config._validate_template_history(self.config.template)
-        if template_contains_history:
-            # Template contains history
+        prompt_contains_history = self.config._validate_prompt_history(self.config.prompt)
+        if prompt_contains_history:
+            # Prompt contains history
             # If there is no history yet, we insert `- no history -`
-            prompt = self.config.template.substitute(
+            prompt = self.config.prompt.substitute(
                 context=context_string, query=input_query, history=self.history or "- no history -"
             )
-        elif self.history and not template_contains_history:
-            # History is present, but not included in the template.
-            # check if it's the default template without history
+        elif self.history and not prompt_contains_history:
+            # History is present, but not included in the prompt.
+            # check if it's the default prompt without history
             if (
-                not self.config._validate_template_history(self.config.template)
-                and self.config.template.template == DEFAULT_PROMPT
+                not self.config._validate_prompt_history(self.config.prompt)
+                and self.config.prompt.template == DEFAULT_PROMPT
             ):
                 # swap in the template with history
                 prompt = DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE.substitute(
@@ -95,12 +95,12 @@ class BaseLlm(JSONSerializable):
             else:
                 # If we can't swap in the default, we still proceed but tell users that the history is ignored.
                 logging.warning(
-                    "Your bot contains a history, but template does not include `$history` key. History is ignored."
+                    "Your bot contains a history, but prompt does not include `$history` key. History is ignored."
                 )
-                prompt = self.config.template.substitute(context=context_string, query=input_query)
+                prompt = self.config.prompt.substitute(context=context_string, query=input_query)
         else:
             # basic use case, no history.
-            prompt = self.config.template.substitute(context=context_string, query=input_query)
+            prompt = self.config.prompt.substitute(context=context_string, query=input_query)
         return prompt
 
     def _append_search_and_context(self, context: str, web_search_result: str) -> str:
@@ -191,7 +191,7 @@ class BaseLlm(JSONSerializable):
                 return contexts
 
             if self.is_docs_site_instance:
-                self.config.template = DOCS_SITE_PROMPT_TEMPLATE
+                self.config.prompt = DOCS_SITE_PROMPT_TEMPLATE
                 self.config.number_documents = 5
             k = {}
             if self.online:
@@ -242,7 +242,7 @@ class BaseLlm(JSONSerializable):
                 self.config = config
 
             if self.is_docs_site_instance:
-                self.config.template = DOCS_SITE_PROMPT_TEMPLATE
+                self.config.prompt = DOCS_SITE_PROMPT_TEMPLATE
                 self.config.number_documents = 5
             k = {}
             if self.online:
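The renaming above does not change how the prompt is filled in; `generate_prompt` still substitutes into a `string.Template`. A standard-library-only sketch of the two substitution paths:

```python
from string import Template

prompt_with_history = Template("Context: $context | Query: $query | History: $history")
prompt_plain = Template("Context: $context | Query: $query")

# When the prompt contains $history, it is filled with the conversation
# history, or with "- no history -" when there is none yet.
print(prompt_with_history.substitute(context="ctx", query="q", history="- no history -"))

# When it does not, only $context and $query are substituted.
print(prompt_plain.substitute(context="ctx", query="q"))
```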

embedchain/utils.py (+1 -0)

@@ -396,6 +396,7 @@ def validate_config(config_data):
                     Optional("top_p"): Or(float, int),
                     Optional("stream"): bool,
                     Optional("template"): str,
+                    Optional("prompt"): str,
                     Optional("system_prompt"): str,
                     Optional("deployment_name"): str,
                     Optional("where"): dict,

tests/helper_classes/test_json_serializable.py (+1 -1)

@@ -76,4 +76,4 @@ class TestJsonSerializable(unittest.TestCase):
         config = BaseLlmConfig(template=Template("My custom template with $query, $context and $history."))
         s = config.serialize()
         new_config: BaseLlmConfig = BaseLlmConfig.deserialize(s)
-        self.assertEqual(config.template.template, new_config.template.template)
+        self.assertEqual(config.prompt.template, new_config.prompt.template)

tests/llm/test_base_llm.py (+1 -1)

@@ -25,7 +25,7 @@ def test_is_stream_bool():
 def test_template_string_gets_converted_to_Template_instance():
     config = BaseLlmConfig(template="test value $query $context")
     llm = BaseLlm(config=config)
-    assert isinstance(llm.config.template, Template)
+    assert isinstance(llm.config.prompt, Template)
 
 
 def test_is_get_llm_model_answer_implemented():

tests/llm/test_generate_prompt.py (+2 -2)

@@ -53,7 +53,7 @@ class TestGeneratePrompt(unittest.TestCase):
         result = self.app.llm.generate_prompt(input_query, contexts)
 
         # Assert
-        expected_result = config.template.substitute(context="Context 1 | Context 2 | Context 3", query=input_query)
+        expected_result = config.prompt.substitute(context="Context 1 | Context 2 | Context 3", query=input_query)
         self.assertEqual(result, expected_result)
 
     def test_generate_prompt_with_history(self):
@@ -61,7 +61,7 @@ class TestGeneratePrompt(unittest.TestCase):
         Test the 'generate_prompt' method with BaseLlmConfig containing a history attribute.
         """
         config = BaseLlmConfig()
-        config.template = Template("Context: $context | Query: $query | History: $history")
+        config.prompt = Template("Context: $context | Query: $query | History: $history")
         self.app.llm.config = config
         self.app.llm.set_history(["Past context 1", "Past context 2"])
         prompt = self.app.llm.generate_prompt("Test query", ["Test context"])
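Not part of this commit, but a hedged sketch of a regression test that could guard the deprecation path alongside the tests above (assuming `BaseLlmConfig` is importable from `embedchain.config`, as these test files appear to do):

```python
import logging
from string import Template

from embedchain.config import BaseLlmConfig  # assumed import path


def test_deprecated_template_arg_still_populates_prompt(caplog):
    with caplog.at_level(logging.WARNING):
        config = BaseLlmConfig(template="test value $query $context")
    # The deprecated keyword is copied into `prompt` and a warning is logged.
    assert isinstance(config.prompt, Template)
    assert "deprecated" in caplog.text
```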