
[Bug fix] Fix history sequence in prompt (#1254)

Deshraj Yadav 1 year ago
parent
commit
2f285ea00a
5 changed files with 22 additions and 19 deletions
  1. embedchain/config/llm/base.py (+2 -1)
  2. embedchain/llm/base.py (+11 -7)
  3. embedchain/llm/openai.py (+6 -8)
  4. embedchain/memory/base.py (+2 -2)
  5. pyproject.toml (+1 -1)

+ 2 - 1
embedchain/config/llm/base.py

@@ -24,7 +24,8 @@ DEFAULT_PROMPT_WITH_HISTORY = """
 
   $context
 
-  History: $history
+  History:
+  $history
 
   Query: $query

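The template tweak above moves $history onto its own line under the "History:" label, which pairs with the newline-joined history string introduced in embedchain/llm/base.py below. A minimal sketch of how the substitution renders, using string.Template with the same placeholders (the surrounding prompt text is abbreviated here, not the full DEFAULT_PROMPT_WITH_HISTORY):

    from string import Template

    # Abbreviated stand-in for the updated template; placeholders match the real one.
    template = Template("  $context\n\n  History:\n  $history\n\n  Query: $query\n")

    print(template.substitute(
        context="(retrieved context)",
        history="Human: Hi\nAI: Hello! How can I help?",
        query="What changed?",
    ))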
+ 11 - 7
embedchain/llm/base.py

@@ -5,9 +5,7 @@ from typing import Any, Optional
 from langchain.schema import BaseMessage as LCBaseMessage
 
 from embedchain.config import BaseLlmConfig
-from embedchain.config.llm.base import (DEFAULT_PROMPT,
-                                        DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
-                                        DOCS_SITE_PROMPT_TEMPLATE)
+from embedchain.config.llm.base import DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE, DOCS_SITE_PROMPT_TEMPLATE
 from embedchain.helpers.json_serializable import JSONSerializable
 from embedchain.memory.base import ChatHistory
 from embedchain.memory.message import ChatMessage
@@ -65,6 +63,14 @@ class BaseLlm(JSONSerializable):
         self.memory.add(app_id=app_id, chat_message=chat_message, session_id=session_id)
         self.update_history(app_id=app_id, session_id=session_id)
 
+    def _format_history(self) -> str:
+        """Format history to be used in prompt
+
+        :return: Formatted history
+        :rtype: str
+        """
+        return "\n".join(self.history)
+
     def generate_prompt(self, input_query: str, contexts: list[str], **kwargs: dict[str, Any]) -> str:
         """
         Generates a prompt based on the given query and context, ready to be
@@ -84,10 +90,8 @@ class BaseLlm(JSONSerializable):
 
         prompt_contains_history = self.config._validate_prompt_history(self.config.prompt)
         if prompt_contains_history:
-            # Prompt contains history
-            # If there is no history yet, we insert `- no history -`
             prompt = self.config.prompt.substitute(
-                context=context_string, query=input_query, history=self.history or "- no history -"
+                context=context_string, query=input_query, history=self._format_history() or "No history"
             )
         elif self.history and not prompt_contains_history:
             # History is present, but not included in the prompt.
@@ -98,7 +102,7 @@ class BaseLlm(JSONSerializable):
             ):
                 # swap in the template with history
                 prompt = DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE.substitute(
-                    context=context_string, query=input_query, history=self.history
+                    context=context_string, query=input_query, history=self._format_history()
                 )
             else:
                 # If we can't swap in the default, we still proceed but tell users that the history is ignored.

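Before this change, self.history (a list of strings) was interpolated into the prompt directly, so it rendered as a Python list literal. A rough sketch of the difference the new _format_history() helper makes, with illustrative history entries (the exact message format is an assumption here):

    history = ["Human: Hi", "AI: Hello! How can I help?"]

    # Old behavior: the list's repr() ends up in the prompt text.
    old = f"History: {history or '- no history -'}"
    # -> "History: ['Human: Hi', 'AI: Hello! How can I help?']"

    # New behavior: entries joined with newlines, as _format_history() does,
    # with "No history" substituted when the list is empty.
    new = "History:\n" + ("\n".join(history) or "No history")
    # -> "History:\nHuman: Hi\nAI: Hello! How can I help?"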
+ 6 - 8
embedchain/llm/openai.py

@@ -35,21 +35,19 @@ class OpenAILlm(BaseLlm):
         if config.top_p:
             kwargs["model_kwargs"]["top_p"] = config.top_p
         if config.stream:
-            from langchain.callbacks.streaming_stdout import \
-                StreamingStdOutCallbackHandler
+            from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 
             callbacks = config.callbacks if config.callbacks else [StreamingStdOutCallbackHandler()]
-            chat = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=callbacks, api_key=api_key)
+            llm = ChatOpenAI(**kwargs, streaming=config.stream, callbacks=callbacks, api_key=api_key)
         else:
-            chat = ChatOpenAI(**kwargs, api_key=api_key)
+            llm = ChatOpenAI(**kwargs, api_key=api_key)
 
         if self.functions is not None:
-            from langchain.chains.openai_functions import \
-                create_openai_fn_runnable
+            from langchain.chains.openai_functions import create_openai_fn_runnable
             from langchain.prompts import ChatPromptTemplate
 
             structured_prompt = ChatPromptTemplate.from_messages(messages)
-            runnable = create_openai_fn_runnable(functions=self.functions, prompt=structured_prompt, llm=chat)
+            runnable = create_openai_fn_runnable(functions=self.functions, prompt=structured_prompt, llm=llm)
             fn_res = runnable.invoke(
                 {
                     "input": prompt,
@@ -57,4 +55,4 @@ class OpenAILlm(BaseLlm):
             )
             messages.append(AIMessage(content=json.dumps(fn_res)))
 
-        return chat(messages).content
+        return llm(messages).content

+ 2 - 2
embedchain/memory/base.py

@@ -92,12 +92,12 @@ class ChatHistory:
         """
 
         if fetch_all:
-            additional_query = "ORDER BY created_at DESC"
+            additional_query = "ORDER BY created_at ASC"
            params = (app_id,)
        else:
            additional_query = """
                AND session_id=?
-                ORDER BY created_at DESC
+                ORDER BY created_at ASC
                LIMIT ?
            """
            params = (app_id, session_id, num_rounds)

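This ORDER BY flip is the "history sequence" fix in the commit title: with DESC the fetched rows came back newest-first, so the conversation read backwards once joined into the prompt, while ASC returns it in chronological order. A minimal sketch of the ordering difference against an in-memory SQLite table (the schema here is simplified for illustration, not the real chat_history layout):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE chat_history (question TEXT, answer TEXT, created_at INTEGER)")
    conn.executemany(
        "INSERT INTO chat_history VALUES (?, ?, ?)",
        [("first question", "first answer", 1), ("second question", "second answer", 2)],
    )

    # DESC: newest exchange first, so the prompt read back-to-front.
    print(conn.execute("SELECT question FROM chat_history ORDER BY created_at DESC").fetchall())
    # [('second question',), ('first question',)]

    # ASC: chronological order, matching what the History block expects.
    print(conn.execute("SELECT question FROM chat_history ORDER BY created_at ASC").fetchall())
    # [('first question',), ('second question',)]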
+ 1 - 1
pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "embedchain"
-version = "0.1.76"
+version = "0.1.77"
 description = "Simplest open source retrieval(RAG) framework"
 authors = [
     "Taranjeet Singh <taranjeet@embedchain.ai>",