feat: History API (client) (#204)

cachho 2 years ago
parent commit 65414af098

+ 3 - 0
README.md

@@ -457,6 +457,7 @@ _coming soon_
 |option|description|type|default|
 |---|---|---|---|
 |template|custom template for prompt|Template|Template("Use the following pieces of context to answer the query at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. \$context Query: $query Helpful Answer:")|
+|history|include conversation history from your client or database|any (recommendation: list[str])|None|
 |stream|control if response is streamed back to the user|bool|False|
 
 #### **Chat Config**
@@ -465,6 +466,8 @@ All options for query and...
 
 _coming soon_
 
+History is handled automatically; the config option is not supported.
+
 ## Other methods
 
 ### Reset

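The README additions above document the new `history` option for queries and the automatic history handling for chat. A minimal usage sketch, assuming the package's `App` entry point as shown in the project README (not part of this diff); the history strings are placeholders:

```python
from embedchain import App  # assumed entry point, as in the project README
from embedchain.config import QueryConfig

app = App()
# app.add(...) calls that build the knowledge base are omitted here.

# History comes from your own client or database, e.g. a list of prior turns.
history = [
    "Q: What topics does the knowledge base cover?",
    "A: It covers the documents you added with app.add().",
]

# Passing history switches the default prompt to the history-aware template.
config = QueryConfig(history=history)
answer = app.query("Can you expand on the last answer?", config)

# chat() keeps its own history, so no history option is passed there.
print(app.chat("And what about streaming?"))
```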
+ 36 - 4
embedchain/config/ChatConfig.py

@@ -1,14 +1,46 @@
 from embedchain.config.QueryConfig import QueryConfig
+from string import Template
+
+DEFAULT_PROMPT = """
+  You are a chatbot having a conversation with a human. You are given chat history and context.
+  You need to answer the query considering context, chat history and your knowledge base. If you don't know the answer or the answer is neither contained in the context nor in history, then simply say "I don't know".
+
+  $context
+
+  History: $history
+
+  Query: $query
+
+  Helpful Answer:
+"""
+
+DEFAULT_PROMPT_TEMPLATE = Template(DEFAULT_PROMPT)
 
 class ChatConfig(QueryConfig):
     """
     Config for the `chat` method, inherits from `QueryConfig`.
     """
-    def __init__(self, stream: bool = False):
+    def __init__(self, template: Template = None, stream: bool = False):
         """
-        Initializes the QueryConfig instance.
+        Initializes the ChatConfig instance.
 
+        :param template: Optional. The `Template` instance to use as a template for prompt.
         :param stream: Optional. Control if response is streamed back to the user
-        :raises ValueError: If the template is not valid as template should contain $context and $query
+        :raises ValueError: If the template is not valid; a chat template must contain $context, $query and $history
         """
-        super().__init__(stream=stream)
+        if template is None:
+            template = DEFAULT_PROMPT_TEMPLATE
+
+        # History is preset to a placeholder ([0]) so that a history always exists and a single template can be used.
+        # Maintaining two templates would add complexity, because the history is not user controlled.
+        super().__init__(template, history=[0], stream=stream)
+
+    def set_history(self, history):
+        """
+        Chat history is not user provided, so it is set at runtime rather than at initialization.
+
+        :param history: (string) history to set
+        """
+        self.history = history
+

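With the `ChatConfig` changes above, a custom chat template has to expose all three placeholders. A short sketch with hypothetical template wording:

```python
from string import Template
from embedchain.config import ChatConfig

# Hypothetical template text; what matters is that $context, $history and $query
# are all present, otherwise ChatConfig raises ValueError.
template = Template("""
  Answer the query from the context and the running conversation.
  Context: $context
  History: $history
  Query: $query
  Answer:
""")

config = ChatConfig(template=template)
# config.set_history(...) is called by chat() itself from its memory buffer;
# it is not meant to be set by the user.
```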
+ 55 - 7
embedchain/config/QueryConfig.py

@@ -13,30 +13,78 @@ DEFAULT_PROMPT = """
   Helpful Answer:
 """
 
+DEFAULT_PROMPT_WITH_HISTORY = """
+  Use the following pieces of context to answer the query at the end.
+  If you don't know the answer, just say that you don't know, don't try to make up an answer.
+  I will provide you with our conversation history. 
+
+  $context
+
+  History: $history
+
+  Query: $query
+
+  Helpful Answer:
+"""
+
 DEFAULT_PROMPT_TEMPLATE = Template(DEFAULT_PROMPT)
+DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE = Template(DEFAULT_PROMPT_WITH_HISTORY)
 query_re = re.compile(r"\$\{*query\}*")
 context_re = re.compile(r"\$\{*context\}*")
+history_re = re.compile(r"\$\{*history\}*")
 
 
 class QueryConfig(BaseConfig):
     """
     Config for the `query` method.
     """
-    def __init__(self, template: Template = None, stream: bool = False):
+    def __init__(self, template: Template = None, history = None, stream: bool = False):
         """
         Initializes the QueryConfig instance.
 
         :param template: Optional. The `Template` instance to use as a template for prompt.
+        :param history: Optional. A list of strings to consider as history.
         :param stream: Optional. Control if response is streamed back to the user
-        :raises ValueError: If the template is not valid as template should contain $context and $query
+        :raises ValueError: If the template is not valid; a template must contain $context and $query (and $history when history is used).
         """
+        # An empty history is treated the same as no history at all.
+        self.history = history if history else None
+
         if template is None:
-            template = DEFAULT_PROMPT_TEMPLATE
-        if not (re.search(query_re, template.template) \
-            and re.search(context_re, template.template)):
-            raise ValueError("`template` should have `query` and `context` keys")
-        self.template = template
+            if self.history is None:
+                template = DEFAULT_PROMPT_TEMPLATE
+            else:
+                template = DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE
+
+        if self.validate_template(template):
+            self.template = template
+        else:
+            if self.history is None:
+                raise ValueError("`template` should have `query` and `context` keys")
+            else:
+                raise ValueError("`template` should have `query`, `context` and `history` keys")
 
         if not isinstance(stream, bool):
             raise ValueError("`stream` should be bool")
         self.stream = stream
+
+    def validate_template(self, template: Template):
+        """
+        Validate that the template contains the required keys.
+
+        :param template: the template to validate
+        :return: truthy if the template is valid, falsy otherwise
+        """
+        if self.history is None:
+            return (re.search(query_re, template.template) \
+                and re.search(context_re, template.template))
+        else:
+            return (re.search(query_re, template.template) \
+                and re.search(context_re, template.template)
+                and re.search(history_re, template.template))

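The `QueryConfig` changes above choose between the two default templates and validate custom ones. A short sketch of the resulting behaviour:

```python
from string import Template
from embedchain.config import QueryConfig

# Without history, the original default template (no $history) is used.
plain = QueryConfig()

# With history, the default switches to the history-aware template.
with_history = QueryConfig(history=["Q: hi", "A: hello"])

# A custom template must carry the matching keys, or __init__ raises ValueError.
try:
    QueryConfig(template=Template("$context $query"), history=["Q: hi"])
except ValueError as err:
    print(err)  # `template` should have `query`, `context` and `history` keys
```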
+ 17 - 30
embedchain/embedchain.py

@@ -11,6 +11,7 @@ from langchain.memory import ConversationBufferMemory
 from embedchain.config import InitConfig, AddConfig, QueryConfig, ChatConfig
 from embedchain.config.QueryConfig import DEFAULT_PROMPT
 from embedchain.data_formatter import DataFormatter
+from string import Template
 
 gpt4all_model = None
 
@@ -144,16 +145,19 @@ class EmbedChain:
             content = ""
         return content
 
-    def generate_prompt(self, input_query, context, template: Template = None):
+    def generate_prompt(self, input_query, context, config: QueryConfig):
         """
         Generates a prompt based on the given query and context, ready to be passed to an LLM
 
         :param input_query: The query to use.
         :param context: Similar documents to the query used as context.
-        :param template: Optional. The `Template` instance to use as a template for prompt.
+        :param config: The `QueryConfig` instance to use as configuration options.
         :return: The prompt
         """
-        prompt = template.substitute(context = context, query = input_query)
+        if not config.history:
+            prompt = config.template.substitute(context = context, query = input_query)
+        else:
+            prompt = config.template.substitute(context = context, query = input_query, history = config.history)
         return prompt
 
     def get_answer_from_llm(self, prompt, config: ChatConfig):
@@ -181,30 +185,12 @@ class EmbedChain:
         if config is None:
             config = QueryConfig()
         context = self.retrieve_from_database(input_query)
-        prompt = self.generate_prompt(input_query, context, config.template)
+        prompt = self.generate_prompt(input_query, context, config)
         logging.info(f"Prompt: {prompt}")
         answer = self.get_answer_from_llm(prompt, config)
         logging.info(f"Answer: {answer}")
         return answer
 
-    def generate_chat_prompt(self, input_query, context, chat_history=''):
-        """
-        Generates a prompt based on the given query, context and chat history
-        for chat interface. This is then passed to an LLM.
-
-        :param input_query: The query to use.
-        :param context: Similar documents to the query used as context.
-        :param chat_history: User and bot conversation that happened before.
-        :return: The prompt
-        """
-        prefix_prompt = f"""You are a chatbot having a conversation with a human. You are given chat history and context. You need to answer the query considering context, chat history and your knowledge base. If you don't know the answer or the answer is neither contained in the context nor in history, then simply say "I don't know"."""
-        chat_history_prompt = f"""\n----\nChat History: {chat_history}\n----"""
-        suffix_prompt = f"""\n####\nContext: {context}\n####\nQuery: {input_query}\nHelpful Answer:"""
-        prompt = prefix_prompt
-        if chat_history:
-            prompt += chat_history_prompt
-        prompt += suffix_prompt
-        return prompt
 
     def chat(self, input_query, config: ChatConfig = None):
         """
@@ -217,18 +203,19 @@ class EmbedChain:
         :param config: Optional. The `ChatConfig` instance to use as configuration options.
         :return: The answer to the query.
         """
-        if config is None:
-            config = ChatConfig()
         context = self.retrieve_from_database(input_query)
         global memory
         chat_history = memory.load_memory_variables({})["history"]
-        prompt = self.generate_chat_prompt(
-            input_query,
-            context,
-            chat_history=chat_history,
-        )
+        
+        if config is None:
+            config = ChatConfig()
+        if chat_history:
+            config.set_history(chat_history)
+            
+        prompt = self.generate_prompt(input_query, context, config)
         logging.info(f"Prompt: {prompt}")
         answer = self.get_answer_from_llm(prompt, config)
+
         memory.chat_memory.add_user_message(input_query)
         
         if isinstance(answer, str):
@@ -264,7 +251,7 @@ class EmbedChain:
         if config is None:
             config = QueryConfig()
         context = self.retrieve_from_database(input_query)
-        prompt = self.generate_prompt(input_query, context, config.template)
+        prompt = self.generate_prompt(input_query, context, config)
         logging.info(f"Prompt: {prompt}")
         return prompt
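
With the `embedchain.py` changes above, `chat()` injects its memory buffer into the config before calling `generate_prompt`. For reference, the substitution then reduces to roughly the following (the history string below is a placeholder for what `chat()` pulls from LangChain's `ConversationBufferMemory`):

```python
from embedchain.config import ChatConfig

config = ChatConfig()
# chat() fills this in from its memory buffer; a placeholder here.
config.set_history("Human: hi\nAI: hello")

prompt = config.template.substitute(
    context="<retrieved documents>",
    query="What changed in this commit?",
    history=config.history,
)
print(prompt)
```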