[bug-fix] fix issue related to bot memory when using multiple bots at the same time (#486)

Deshraj Yadav, 2 years ago
Parent commit: 4388f6bfc2
1 file changed, 5 additions and 7 deletions
  1. embedchain/embedchain.py  +5 −7

embedchain/embedchain.py  +5 −7

@@ -26,8 +26,6 @@ load_dotenv()
 ABS_PATH = os.getcwd()
 DB_DIR = os.path.join(ABS_PATH, "db")
 
-memory = ConversationBufferMemory()
-
 
 class EmbedChain:
     def __init__(self, config: BaseAppConfig):
@@ -44,6 +42,7 @@ class EmbedChain:
         self.user_asks = []
         self.is_docs_site_instance = False
         self.online = False
+        self.memory = ConversationBufferMemory()
 
         # Send anonymous telemetry
         self.s_id = self.config.id if self.config.id else str(uuid.uuid4())
@@ -362,8 +361,7 @@ class EmbedChain:
             k["web_search_result"] = self.access_search_and_get_results(input_query)
         contexts = self.retrieve_from_database(input_query, config)
 
-        global memory
-        chat_history = memory.load_memory_variables({})["history"]
+        chat_history = self.memory.load_memory_variables({})["history"]
 
         if chat_history:
             config.set_history(chat_history)
@@ -376,14 +374,14 @@ class EmbedChain:
 
         answer = self.get_answer_from_llm(prompt, config)
 
-        memory.chat_memory.add_user_message(input_query)
+        self.memory.chat_memory.add_user_message(input_query)
 
         # Send anonymous telemetry
         thread_telemetry = threading.Thread(target=self._send_telemetry_event, args=("chat",))
         thread_telemetry.start()
 
         if isinstance(answer, str):
-            memory.chat_memory.add_ai_message(answer)
+            self.memory.chat_memory.add_ai_message(answer)
             logging.info(f"Answer: {answer}")
             return answer
         else:
@@ -395,7 +393,7 @@ class EmbedChain:
         for chunk in answer:
             streamed_answer = streamed_answer + chunk
             yield chunk
-        memory.chat_memory.add_ai_message(streamed_answer)
+        self.memory.chat_memory.add_ai_message(streamed_answer)
         logging.info(f"Answer: {streamed_answer}")
 
     def set_collection(self, collection_name):
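
Taken together, the hunks above move the ConversationBufferMemory from a module-level global into an instance attribute, so each EmbedChain app keeps its own chat history instead of sharing one buffer across every bot in the process. Below is a minimal sketch of the difference; BotBefore and BotAfter are hypothetical stand-ins for the class before and after this commit, and only the ConversationBufferMemory calls already visible in the diff are used.

from langchain.memory import ConversationBufferMemory

# Before this commit: one module-level memory shared by every bot instance.
shared_memory = ConversationBufferMemory()

class BotBefore:
    """Hypothetical stand-in for EmbedChain before the fix."""
    def chat(self, query: str, answer: str) -> str:
        shared_memory.chat_memory.add_user_message(query)
        shared_memory.chat_memory.add_ai_message(answer)
        return shared_memory.load_memory_variables({})["history"]

class BotAfter:
    """Hypothetical stand-in for EmbedChain after the fix."""
    def __init__(self):
        # Each instance owns its own memory, mirroring self.memory in __init__.
        self.memory = ConversationBufferMemory()

    def chat(self, query: str, answer: str) -> str:
        self.memory.chat_memory.add_user_message(query)
        self.memory.chat_memory.add_ai_message(answer)
        return self.memory.load_memory_variables({})["history"]

if __name__ == "__main__":
    a, b = BotBefore(), BotBefore()
    a.chat("Who wrote Dune?", "Frank Herbert.")
    # b's history already contains a's exchange: memory leaks across bots.
    print(b.chat("Who wrote Hyperion?", "Dan Simmons."))

    c, d = BotAfter(), BotAfter()
    c.chat("Who wrote Dune?", "Frank Herbert.")
    # d's history contains only its own exchange.
    print(d.chat("Who wrote Hyperion?", "Dan Simmons."))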