Explorar el Código

[Improvement] Add support for reloading history for an existing app (#930)

Deshraj Yadav hace 1 año
padre
commit
17129e2eaa

+ 12 - 2
embedchain/embedchain.py

@@ -74,6 +74,9 @@ class EmbedChain(JSONSerializable):
         if system_prompt:
             self.llm.config.system_prompt = system_prompt
 
+        # Fetch the history from the database if exists
+        self.llm.update_history(app_id=self.config.id)
+
         # Attributes that aren't subclass related.
         self.user_asks = []
 
@@ -641,9 +644,16 @@ class EmbedChain(JSONSerializable):
         self.db.reset()
         self.cursor.execute("DELETE FROM data_sources WHERE pipeline_id = ?", (self.config.id,))
         self.connection.commit()
-        self.clear_history()
+        self.delete_history()
         # Send anonymous telemetry
         self.telemetry.capture(event_name="reset", properties=self._telemetry_props)
 
-    def clear_history(self):
+    def get_history(self, num_rounds: int = 10, display_format: bool = True):
+        return self.llm.memory.get_recent_memories(
+            app_id=self.config.id,
+            num_rounds=num_rounds,
+            display_format=display_format,
+        )
+
+    def delete_history(self):
         self.llm.memory.delete_chat_history(app_id=self.config.id)

+ 10 - 6
embedchain/memory/base.py

@@ -62,7 +62,7 @@ class ECChatMemory:
         )
         self.connection.commit()
 
-    def get_recent_memories(self, app_id, num_rounds=10) -> List[ChatMessage]:
+    def get_recent_memories(self, app_id, num_rounds=10, display_format=False) -> List[ChatMessage]:
         """
         Get the most recent num_rounds rounds of conversations
         between human and AI, for a given app_id.
@@ -82,12 +82,16 @@ class ECChatMemory:
         results = self.cursor.fetchall()
         history = []
         for result in results:
-            app_id, id, question, answer, metadata, timestamp = result
+            app_id, _, question, answer, metadata, timestamp = result
             metadata = self._deserialize_json(metadata=metadata)
-            memory = ChatMessage()
-            memory.add_user_message(question, metadata=metadata)
-            memory.add_ai_message(answer, metadata=metadata)
-            history.append(memory)
+            # Return list of dict if display_format is True
+            if display_format:
+                history.append({"human": question, "ai": answer, "metadata": metadata, "timestamp": timestamp})
+            else:
+                memory = ChatMessage()
+                memory.add_user_message(question, metadata=metadata)
+                memory.add_ai_message(answer, metadata=metadata)
+                history.append(memory)
         return history
 
     def _serialize_json(self, metadata: Dict[str, Any]):

+ 1 - 1
embedchain/memory/message.py

@@ -69,4 +69,4 @@ class ChatMessage(JSONSerializable):
         self.ai_message = BaseMessage(content=message, creator="ai", metadata=metadata)
 
     def __str__(self) -> str:
-        return f"{self.human_message} | {self.ai_message}"
+        return f"{self.human_message}\n{self.ai_message}"

+ 1 - 3
pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "embedchain"
-version = "0.1.3"
+version = "0.1.4"
 description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
 authors = [
     "Taranjeet Singh <taranjeet@embedchain.ai>",
@@ -191,6 +191,4 @@ postgres = ["psycopg", "psycopg-binary", "psycopg-pool"]
 
 [tool.poetry.group.docs.dependencies]
 
-
-
 [tool.poetry.scripts]

+ 1 - 1
tests/memory/test_memory_messages.py

@@ -34,4 +34,4 @@ def test_ec_base_chat_message():
     assert chat_message.ai_message.creator == "ai"
     assert chat_message.ai_message.metadata == ai_metadata
 
-    assert str(chat_message) == f"human: {human_message_content} | ai: {ai_message_content}"
+    assert str(chat_message) == f"human: {human_message_content}\nai: {ai_message_content}"