
Migrate from `template` to `prompt` arg while keeping backward compatibility (#1066)

Sidharth Mohanty, 1 year ago
commit d9d529987e
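
For illustration, a minimal sketch of what this migration means at the call site (import path assumed; the kwargs follow the diff below). The deprecated `template` kwarg keeps working but now logs a warning:

from string import Template

from embedchain.config import BaseLlmConfig

# Old style (deprecated): still accepted, but logs a deprecation warning.
old = BaseLlmConfig(template=Template("Context: $context\nQuery: $query"))

# New style: the same value, passed as `prompt`.
new = BaseLlmConfig(prompt=Template("Context: $context\nQuery: $query"))

# Both end up on the new `prompt` attribute.
assert old.prompt.template == new.prompt.template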

+ 1 - 1
configs/full-stack.yaml

@@ -15,7 +15,7 @@ llm:
     max_tokens: 1000
     top_p: 1
     stream: false
-    template: |
+    prompt: |
       Use the following pieces of context to answer the query at the end.
       If you don't know the answer, just say that you don't know, don't try to make up an answer.
 

+ 4 - 4
docs/api-reference/advanced/configuration.mdx

@@ -26,7 +26,7 @@ llm:
     top_p: 1
     stream: false
     api_key: sk-xxx
-    template: |
+    prompt: |
       Use the following pieces of context to answer the query at the end.
       If you don't know the answer, just say that you don't know, don't try to make up an answer.
 
@@ -73,7 +73,7 @@ chunker:
       "max_tokens": 1000,
       "max_tokens": 1000,
       "top_p": 1,
       "top_p": 1,
       "stream": false,
       "stream": false,
-      "template": "Use the following pieces of context to answer the query at the end.\nIf you don't know the answer, just say that you don't know, don't try to make up an answer.\n$context\n\nQuery: $query\n\nHelpful Answer:",
+      "prompt": "Use the following pieces of context to answer the query at the end.\nIf you don't know the answer, just say that you don't know, don't try to make up an answer.\n$context\n\nQuery: $query\n\nHelpful Answer:",
       "system_prompt": "Act as William Shakespeare. Answer the following questions in the style of William Shakespeare.",
       "system_prompt": "Act as William Shakespeare. Answer the following questions in the style of William Shakespeare.",
       "api_key": "sk-xxx"
       "api_key": "sk-xxx"
     }
     }
@@ -117,7 +117,7 @@ config = {
             'max_tokens': 1000,
             'top_p': 1,
             'stream': False,
-            'template': (
+            'prompt': (
                 "Use the following pieces of context to answer the query at the end.\n"
                 "If you don't know the answer, just say that you don't know, don't try to make up an answer.\n"
                 "$context\n\nQuery: $query\n\nHelpful Answer:"
@@ -170,7 +170,7 @@ Alright, let's dive into what each key means in the yaml config above:
         - `max_tokens` (Integer): Controls how many tokens are used in the response.
         - `top_p` (Float): Controls the diversity of word selection. A higher value (closer to 1) makes word selection more diverse.
         - `stream` (Boolean): Controls if the response is streamed back to the user (set to false).
-        - `template` (String): A custom template for the prompt that the model uses to generate responses.
+        - `prompt` (String): A prompt for the model to follow when generating responses; it must contain the `$context` and `$query` variables.
         - `system_prompt` (String): A system prompt for the model to follow when generating responses, in this case, it's set to the style of William Shakespeare.
         -  `stream` (Boolean): Controls if the response is streamed back to the user (set to false).
         - `number_documents` (Integer): Number of documents to pull from the vectordb as context, defaults to 1
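
As a rough sketch of how the `$context`/`$query` requirement is enforced (per `validate_prompt` in the diff further below; import path assumed):

from embedchain.config import BaseLlmConfig

# Plain strings are converted to `string.Template` internally.
config = BaseLlmConfig(prompt="Use $context to answer: $query")

# A prompt missing the placeholders is rejected with a ValueError.
try:
    BaseLlmConfig(prompt="Answer the question.")
except ValueError as err:
    print(err)  # complains about the missing 'query'/'context' keys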

+ 1 - 1
docs/examples/rest-api/create.mdx

@@ -37,7 +37,7 @@ llm:
     max_tokens: 1000
     top_p: 1
     stream: false
-    template: |
+    prompt: |
       Use the following pieces of context to answer the query at the end.
       If you don't know the answer, just say that you don't know, don't try to make up an answer.
 

+ 31 - 18
embedchain/config/llm/base.py

@@ -1,3 +1,4 @@
+import logging
 import re
 from string import Template
 from typing import Any, Dict, List, Optional
@@ -59,6 +60,7 @@ class BaseLlmConfig(BaseConfig):
         self,
         number_documents: int = 3,
         template: Optional[Template] = None,
+        prompt: Optional[Template] = None,
         model: Optional[str] = None,
         temperature: float = 0,
         max_tokens: int = 1000,
@@ -80,8 +82,11 @@ class BaseLlmConfig(BaseConfig):
         context, defaults to 1
         :type number_documents: int, optional
         :param template:  The `Template` instance to use as a template for
-        prompt, defaults to None
+        prompt, defaults to None (deprecated)
         :type template: Optional[Template], optional
+        :param prompt: The `Template` instance to use as a template for
+        prompt, defaults to None
+        :type prompt: Optional[Template], optional
         :param model: Controls the OpenAI model used, defaults to None
         :type model: Optional[str], optional
         :param temperature:  Controls the randomness of the model's output.
@@ -106,8 +111,16 @@ class BaseLlmConfig(BaseConfig):
         contain $context and $query (and optionally $history)
         :raises ValueError: Stream is not boolean
         """
-        if template is None:
-            template = DEFAULT_PROMPT_TEMPLATE
+        if template is not None:
+            logging.warning(
+                "The `template` argument is deprecated and will be removed in a future version. "
+                + "Please use `prompt` instead."
+            )
+            if prompt is None:
+                prompt = template
+
+        if prompt is None:
+            prompt = DEFAULT_PROMPT_TEMPLATE
 
         self.number_documents = number_documents
         self.temperature = temperature
@@ -120,37 +133,37 @@ class BaseLlmConfig(BaseConfig):
         self.callbacks = callbacks
         self.api_key = api_key
 
-        if type(template) is str:
-            template = Template(template)
+        if type(prompt) is str:
+            prompt = Template(prompt)
 
-        if self.validate_template(template):
-            self.template = template
+        if self.validate_prompt(prompt):
+            self.prompt = prompt
         else:
-            raise ValueError("`template` should have `query` and `context` keys and potentially `history` (if used).")
+            raise ValueError("The 'prompt' should have 'query' and 'context' keys and potentially 'history' (if used).")
 
         if not isinstance(stream, bool):
             raise ValueError("`stream` should be bool")
         self.stream = stream
         self.where = where
 
-    def validate_template(self, template: Template) -> bool:
+    def validate_prompt(self, prompt: Template) -> bool:
         """
-        validate the template
+        validate the prompt
 
-        :param template: the template to validate
-        :type template: Template
+        :param prompt: the prompt to validate
+        :type prompt: Template
         :return: valid (true) or invalid (false)
         :rtype: bool
         """
-        return re.search(query_re, template.template) and re.search(context_re, template.template)
+        return re.search(query_re, prompt.template) and re.search(context_re, prompt.template)
 
-    def _validate_template_history(self, template: Template) -> bool:
+    def _validate_prompt_history(self, prompt: Template) -> bool:
         """
-        validate the template with history
+        validate the prompt with history
 
-        :param template: the template to validate
-        :type template: Template
+        :param prompt: the prompt to validate
+        :type prompt: Template
         :return: valid (true) or invalid (false)
         :rtype: bool
         """
-        return re.search(history_re, template.template)
+        return re.search(history_re, prompt.template)
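
A short sketch of the renamed validators in use (assuming the default prompt contains `$context` and `$query`, as the docstring above states; import path assumed):

from string import Template

from embedchain.config import BaseLlmConfig

config = BaseLlmConfig()

# The default prompt carries $context and $query, so it validates.
assert config.validate_prompt(config.prompt)

# The private helper only checks for an optional $history placeholder.
with_history = Template("History: $history | Context: $context | Query: $query")
assert config._validate_prompt_history(with_history)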

+ 14 - 14
embedchain/llm/base.py

@@ -74,19 +74,19 @@ class BaseLlm(JSONSerializable):
         if web_search_result:
             context_string = self._append_search_and_context(context_string, web_search_result)
 
-        template_contains_history = self.config._validate_template_history(self.config.template)
-        if template_contains_history:
-            # Template contains history
+        prompt_contains_history = self.config._validate_prompt_history(self.config.prompt)
+        if prompt_contains_history:
+            # Prompt contains history
             # If there is no history yet, we insert `- no history -`
-            prompt = self.config.template.substitute(
+            prompt = self.config.prompt.substitute(
                 context=context_string, query=input_query, history=self.history or "- no history -"
             )
-        elif self.history and not template_contains_history:
-            # History is present, but not included in the template.
-            # check if it's the default template without history
+        elif self.history and not prompt_contains_history:
+            # History is present, but not included in the prompt.
+            # check if it's the default prompt without history
             if (
-                not self.config._validate_template_history(self.config.template)
-                and self.config.template.template == DEFAULT_PROMPT
+                not self.config._validate_prompt_history(self.config.prompt)
+                and self.config.prompt.template == DEFAULT_PROMPT
             ):
                 # swap in the template with history
                 prompt = DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE.substitute(
@@ -95,12 +95,12 @@ class BaseLlm(JSONSerializable):
             else:
                 # If we can't swap in the default, we still proceed but tell users that the history is ignored.
                 logging.warning(
-                    "Your bot contains a history, but template does not include `$history` key. History is ignored."
+                    "Your bot contains a history, but prompt does not include `$history` key. History is ignored."
                 )
-                prompt = self.config.template.substitute(context=context_string, query=input_query)
+                prompt = self.config.prompt.substitute(context=context_string, query=input_query)
         else:
             # basic use case, no history.
-            prompt = self.config.template.substitute(context=context_string, query=input_query)
+            prompt = self.config.prompt.substitute(context=context_string, query=input_query)
         return prompt
 
     def _append_search_and_context(self, context: str, web_search_result: str) -> str:
@@ -191,7 +191,7 @@ class BaseLlm(JSONSerializable):
                 return contexts
 
             if self.is_docs_site_instance:
-                self.config.template = DOCS_SITE_PROMPT_TEMPLATE
+                self.config.prompt = DOCS_SITE_PROMPT_TEMPLATE
                 self.config.number_documents = 5
             k = {}
             if self.online:
@@ -242,7 +242,7 @@ class BaseLlm(JSONSerializable):
                 self.config = config
 
             if self.is_docs_site_instance:
-                self.config.template = DOCS_SITE_PROMPT_TEMPLATE
+                self.config.prompt = DOCS_SITE_PROMPT_TEMPLATE
                 self.config.number_documents = 5
             k = {}
             if self.online:
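
To summarize the branching above, a simplified, self-contained sketch of the history handling (it leaves out the web-search merge and the default-prompt swap, and uses a plain substring check instead of the `history_re` regex):

from string import Template

def build_prompt(prompt: Template, context: str, query: str, history=None) -> str:
    # If the prompt declares $history, substitute it (or a placeholder).
    if "$history" in prompt.template:
        return prompt.substitute(context=context, query=query,
                                 history=history or "- no history -")
    # Otherwise any history is ignored, as the warning above notes.
    return prompt.substitute(context=context, query=query)

print(build_prompt(Template("C: $context | Q: $query"), "some context", "a question"))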

+ 1 - 0
embedchain/utils.py

@@ -396,6 +396,7 @@ def validate_config(config_data):
                     Optional("top_p"): Or(float, int),
                     Optional("top_p"): Or(float, int),
                     Optional("stream"): bool,
                     Optional("stream"): bool,
                     Optional("template"): str,
                     Optional("template"): str,
+                    Optional("prompt"): str,
                     Optional("system_prompt"): str,
                     Optional("system_prompt"): str,
                     Optional("deployment_name"): str,
                     Optional("deployment_name"): str,
                     Optional("where"): dict,
                     Optional("where"): dict,

+ 1 - 1
tests/helper_classes/test_json_serializable.py

@@ -76,4 +76,4 @@ class TestJsonSerializable(unittest.TestCase):
         config = BaseLlmConfig(template=Template("My custom template with $query, $context and $history."))
         s = config.serialize()
         new_config: BaseLlmConfig = BaseLlmConfig.deserialize(s)
-        self.assertEqual(config.template.template, new_config.template.template)
+        self.assertEqual(config.prompt.template, new_config.prompt.template)

+ 1 - 1
tests/llm/test_base_llm.py

@@ -25,7 +25,7 @@ def test_is_stream_bool():
 def test_template_string_gets_converted_to_Template_instance():
     config = BaseLlmConfig(template="test value $query $context")
     llm = BaseLlm(config=config)
-    assert isinstance(llm.config.template, Template)
+    assert isinstance(llm.config.prompt, Template)
 
 
 def test_is_get_llm_model_answer_implemented():

+ 2 - 2
tests/llm/test_generate_prompt.py

@@ -53,7 +53,7 @@ class TestGeneratePrompt(unittest.TestCase):
         result = self.app.llm.generate_prompt(input_query, contexts)
 
         # Assert
-        expected_result = config.template.substitute(context="Context 1 | Context 2 | Context 3", query=input_query)
+        expected_result = config.prompt.substitute(context="Context 1 | Context 2 | Context 3", query=input_query)
         self.assertEqual(result, expected_result)
 
     def test_generate_prompt_with_history(self):
@@ -61,7 +61,7 @@ class TestGeneratePrompt(unittest.TestCase):
         Test the 'generate_prompt' method with BaseLlmConfig containing a history attribute.
         """
         config = BaseLlmConfig()
-        config.template = Template("Context: $context | Query: $query | History: $history")
+        config.prompt = Template("Context: $context | Query: $query | History: $history")
         self.app.llm.config = config
         self.app.llm.set_history(["Past context 1", "Past context 2"])
         prompt = self.app.llm.generate_prompt("Test query", ["Test context"])
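
A hypothetical companion test (not part of this diff; `BaseLlmConfig` imported as in the tests above) could pin the backward-compatibility path down directly, since the deprecated kwarg is copied onto the new attribute:

def test_template_arg_still_populates_prompt():
    # Deprecated kwarg: the value should be carried over to `config.prompt`.
    config = BaseLlmConfig(template="test value $query $context")
    assert config.prompt.template == "test value $query $context"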
         prompt = self.app.llm.generate_prompt("Test query", ["Test context"])