@@ -1,14 +1,15 @@
 import logging
 from typing import Any, Dict, Generator, List, Optional

-from langchain.memory import ConversationBufferMemory
-from langchain.schema import BaseMessage
+from langchain.schema import BaseMessage as LCBaseMessage

 from embedchain.config import BaseLlmConfig
 from embedchain.config.llm.base import (DEFAULT_PROMPT,
                                         DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
                                         DOCS_SITE_PROMPT_TEMPLATE)
 from embedchain.helper.json_serializable import JSONSerializable
+from embedchain.memory.base import ECChatMemory
+from embedchain.memory.message import ChatMessage


 class BaseLlm(JSONSerializable):
@@ -23,7 +24,7 @@ class BaseLlm(JSONSerializable):
         else:
             self.config = config

-        self.memory = ConversationBufferMemory()
+        self.memory = ECChatMemory()
         self.is_docs_site_instance = False
         self.online = False
         self.history: Any = None
@@ -44,11 +45,18 @@ class BaseLlm(JSONSerializable):
         """
         self.history = history

-    def update_history(self):
+    def update_history(self, app_id: str):
         """Update class history attribute with history in memory (for chat method)"""
-        chat_history = self.memory.load_memory_variables({})["history"]
+        chat_history = self.memory.get_recent_memories(app_id=app_id, num_rounds=10)
         if chat_history:
-            self.set_history(chat_history)
+            self.set_history([str(history) for history in chat_history])
+
+    def add_history(self, app_id: str, question: str, answer: str, metadata: Optional[Dict[str, Any]] = None):
+        chat_message = ChatMessage()
+        chat_message.add_user_message(question, metadata=metadata)
+        chat_message.add_ai_message(answer, metadata=metadata)
+        self.memory.add(app_id=app_id, chat_message=chat_message)
+        self.update_history(app_id=app_id)

     def generate_prompt(self, input_query: str, contexts: List[str], **kwargs: Dict[str, Any]) -> str:
         """
@@ -165,7 +173,6 @@ class BaseLlm(JSONSerializable):
         for chunk in answer:
             streamed_answer = streamed_answer + chunk
             yield chunk
-        self.memory.chat_memory.add_ai_message(streamed_answer)
         logging.info(f"Answer: {streamed_answer}")

     def query(self, input_query: str, contexts: List[str], config: BaseLlmConfig = None, dry_run=False):
@@ -257,8 +264,6 @@ class BaseLlm(JSONSerializable):
             if self.online:
                 k["web_search_result"] = self.access_search_and_get_results(input_query)

-            self.update_history()
-
             prompt = self.generate_prompt(input_query, contexts, **k)
             logging.info(f"Prompt: {prompt}")

@@ -267,16 +272,9 @@ class BaseLlm(JSONSerializable):

             answer = self.get_answer_from_llm(prompt)

-            self.memory.chat_memory.add_user_message(input_query)
-
             if isinstance(answer, str):
-                self.memory.chat_memory.add_ai_message(answer)
                 logging.info(f"Answer: {answer}")

-                # NOTE: Adding to history before and after. This could be seen as redundant.
-                # If we change it, we have to change the tests (no big deal).
-                self.update_history()
-
                 return answer
             else:
                 # this is a streamed response and needs to be handled differently.
@@ -287,7 +285,7 @@ class BaseLlm(JSONSerializable):
                 self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config)

     @staticmethod
-    def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[BaseMessage]:
+    def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[LCBaseMessage]:
         """
         Construct a list of langchain messages