Forráskód Böngészése

[Improvement] update LLM memory get function (#1162)

Co-authored-by: Deven Patel <deven298@yahoo.com>
Deven Patel 1 éve
szülő
commit
c020e65a50

+ 5 - 2
embedchain/embedchain.py

@@ -667,8 +667,11 @@ class EmbedChain(JSONSerializable):
         # Send anonymous telemetry
         self.telemetry.capture(event_name="reset", properties=self._telemetry_props)
 
-    def get_history(self, num_rounds: int = 10, display_format: bool = True):
-        return self.llm.memory.get(app_id=self.config.id, num_rounds=num_rounds, display_format=display_format)
+    def get_history(self, num_rounds: int = 10, display_format: bool = True, session_id: Optional[str] = "default"):
+        history = self.llm.memory.get(
+            app_id=self.config.id, session_id=session_id, num_rounds=num_rounds, display_format=display_format
+        )
+        return history
 
     def delete_session_chat_history(self, session_id: str = "default"):
         self.llm.memory.delete(app_id=self.config.id, session_id=session_id)

+ 36 - 9
embedchain/memory/base.py

@@ -73,21 +73,40 @@ class ChatHistory:
         self.cursor.execute(DELETE_CHAT_HISTORY_QUERY, params)
         self.connection.commit()
 
-    def get(self, app_id, session_id, num_rounds=10, display_format=False) -> list[ChatMessage]:
+    def get(
+        self, app_id, session_id: str = "default", num_rounds=10, fetch_all: bool = False, display_format=False
+    ) -> list[ChatMessage]:
         """
-        Get the most recent num_rounds rounds of conversations
-        between human and AI, for a given app_id.
+        Get the chat history for a given app_id.
+
+        param: app_id - The app_id to get chat history
+        param: session_id (optional) - The session_id to get chat history. Defaults to "default"
+        param: num_rounds (optional) - The number of rounds to get chat history. Defaults to 10
+        param: fetch_all (optional) - Whether to fetch all chat history or not. Defaults to False
+        param: display_format (optional) - Whether to return the chat history in display format. Defaults to False
         """
 
-        QUERY = """
+        base_query = """
             SELECT * FROM ec_chat_history
-            WHERE app_id=? AND session_id=?
-            ORDER BY created_at DESC
-            LIMIT ?
+            WHERE app_id=?
         """
+
+        if fetch_all:
+            additional_query = "ORDER BY created_at DESC"
+            params = (app_id,)
+        else:
+            additional_query = """
+                AND session_id=?
+                ORDER BY created_at DESC
+                LIMIT ?
+            """
+            params = (app_id, session_id, num_rounds)
+
+        QUERY = base_query + additional_query
+
         self.cursor.execute(
             QUERY,
-            (app_id, session_id, num_rounds),
+            params,
         )
 
         results = self.cursor.fetchall()
@@ -97,7 +116,15 @@ class ChatHistory:
             metadata = self._deserialize_json(metadata=metadata)
             # Return list of dict if display_format is True
             if display_format:
-                history.append({"human": question, "ai": answer, "metadata": metadata, "timestamp": timestamp})
+                history.append(
+                    {
+                        "session_id": session_id,
+                        "human": question,
+                        "ai": answer,
+                        "metadata": metadata,
+                        "timestamp": timestamp,
+                    }
+                )
             else:
                 memory = ChatMessage()
                 memory.add_user_message(question, metadata=metadata)

+ 1 - 1
pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "embedchain"
-version = "0.1.62"
+version = "0.1.63"
 description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
 authors = [
     "Taranjeet Singh <taranjeet@embedchain.ai>",

+ 4 - 0
tests/memory/test_chat_memory.py

@@ -44,6 +44,10 @@ def test_get(chat_memory_instance):
 
     assert len(recent_memories) == 5
 
+    all_memories = chat_memory_instance.get(app_id, fetch_all=True)
+
+    assert len(all_memories) == 6
+
 
 def test_delete_chat_history(chat_memory_instance):
     app_id = "test_app"