瀏覽代碼

[Improvement] Use SQLite for chat memory (#910)

Co-authored-by: Deven Patel <deven298@yahoo.com>
Deven Patel 1 年之前
父節點
當前提交
654fd8d74c

+ 1 - 1
embedchain/client.py

@@ -5,7 +5,7 @@ import uuid
 
 import requests
 
-from embedchain.embedchain import CONFIG_DIR, CONFIG_FILE
+from embedchain.constants import CONFIG_DIR, CONFIG_FILE
 
 
 class Client:

+ 8 - 0
embedchain/constants.py

@@ -0,0 +1,8 @@
+import os
+from pathlib import Path
+
+# Working directory of the process at the time this module is imported.
+ABS_PATH = os.getcwd()
+# Current user's home directory, as a plain string.
+HOME_DIR = str(Path.home())
+# Per-user ".embedchain" directory; the two files below live inside it.
+CONFIG_DIR = os.path.join(HOME_DIR, ".embedchain")
+CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
+# SQLite database file located in CONFIG_DIR.
+SQLITE_PATH = os.path.join(CONFIG_DIR, "embedchain.db")

+ 8 - 8
embedchain/embedchain.py

@@ -1,9 +1,7 @@
 import hashlib
 import json
 import logging
-import os
 import sqlite3
-from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 from dotenv import load_dotenv
@@ -12,6 +10,7 @@ from langchain.docstore.document import Document
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
 from embedchain.config.apps.base_app_config import BaseAppConfig
+from embedchain.constants import SQLITE_PATH
 from embedchain.data_formatter import DataFormatter
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.helper.json_serializable import JSONSerializable
@@ -25,12 +24,6 @@ from embedchain.vectordb.base import BaseVectorDB
 
 load_dotenv()
 
-ABS_PATH = os.getcwd()
-HOME_DIR = str(Path.home())
-CONFIG_DIR = os.path.join(HOME_DIR, ".embedchain")
-CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
-SQLITE_PATH = os.path.join(CONFIG_DIR, "embedchain.db")
-
 
 class EmbedChain(JSONSerializable):
     def __init__(
@@ -602,6 +595,9 @@ class EmbedChain(JSONSerializable):
             input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
         )
 
+        # add conversation in memory
+        self.llm.add_history(self.config.id, input_query, answer)
+
         # Send anonymous telemetry
         self.telemetry.capture(event_name="chat", properties=self._telemetry_props)
 
@@ -645,5 +641,9 @@ class EmbedChain(JSONSerializable):
         self.db.reset()
         self.cursor.execute("DELETE FROM data_sources WHERE pipeline_id = ?", (self.config.id,))
         self.connection.commit()
+        self.clear_history()
         # Send anonymous telemetry
         self.telemetry.capture(event_name="reset", properties=self._telemetry_props)
+
+    def clear_history(self):
+        """Delete this app's chat history (keyed by ``self.config.id``) via the LLM's memory."""
+        self.llm.memory.delete_chat_history(app_id=self.config.id)

+ 15 - 17
embedchain/llm/base.py

@@ -1,14 +1,15 @@
 import logging
 from typing import Any, Dict, Generator, List, Optional
 
-from langchain.memory import ConversationBufferMemory
-from langchain.schema import BaseMessage
+from langchain.schema import BaseMessage as LCBaseMessage
 
 from embedchain.config import BaseLlmConfig
 from embedchain.config.llm.base import (DEFAULT_PROMPT,
                                         DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
                                         DOCS_SITE_PROMPT_TEMPLATE)
 from embedchain.helper.json_serializable import JSONSerializable
+from embedchain.memory.base import ECChatMemory
+from embedchain.memory.message import ChatMessage
 
 
 class BaseLlm(JSONSerializable):
@@ -23,7 +24,7 @@ class BaseLlm(JSONSerializable):
         else:
             self.config = config
 
-        self.memory = ConversationBufferMemory()
+        self.memory = ECChatMemory()
         self.is_docs_site_instance = False
         self.online = False
         self.history: Any = None
@@ -44,11 +45,18 @@ class BaseLlm(JSONSerializable):
         """
         self.history = history
 
-    def update_history(self):
+    def update_history(self, app_id: str):
         """Update class history attribute with history in memory (for chat method)"""
-        chat_history = self.memory.load_memory_variables({})["history"]
+        chat_history = self.memory.get_recent_memories(app_id=app_id, num_rounds=10)
         if chat_history:
-            self.set_history(chat_history)
+            self.set_history([str(history) for history in chat_history])
+
+    def add_history(self, app_id: str, question: str, answer: str, metadata: Optional[Dict[str, Any]] = None):
+        """
+        Store one question/answer exchange in memory, then refresh ``self.history``.
+
+        Args:
+            app_id (str): id the exchange is stored under
+            question (str): the user's query
+            answer (str): the model's reply
+            metadata (Optional[Dict[str, Any]]): extra info; the same object
+            is attached to both the human and the AI message
+        """
+        exchange = ChatMessage()
+        exchange.add_user_message(question, metadata=metadata)
+        exchange.add_ai_message(answer, metadata=metadata)
+        self.memory.add(app_id=app_id, chat_message=exchange)
+        # Re-read recent memories so the history attribute includes this exchange.
+        self.update_history(app_id=app_id)
 
     def generate_prompt(self, input_query: str, contexts: List[str], **kwargs: Dict[str, Any]) -> str:
         """
@@ -165,7 +173,6 @@ class BaseLlm(JSONSerializable):
         for chunk in answer:
             streamed_answer = streamed_answer + chunk
             yield chunk
-        self.memory.chat_memory.add_ai_message(streamed_answer)
         logging.info(f"Answer: {streamed_answer}")
 
     def query(self, input_query: str, contexts: List[str], config: BaseLlmConfig = None, dry_run=False):
@@ -257,8 +264,6 @@ class BaseLlm(JSONSerializable):
             if self.online:
                 k["web_search_result"] = self.access_search_and_get_results(input_query)
 
-            self.update_history()
-
             prompt = self.generate_prompt(input_query, contexts, **k)
             logging.info(f"Prompt: {prompt}")
 
@@ -267,16 +272,9 @@ class BaseLlm(JSONSerializable):
 
             answer = self.get_answer_from_llm(prompt)
 
-            self.memory.chat_memory.add_user_message(input_query)
-
             if isinstance(answer, str):
-                self.memory.chat_memory.add_ai_message(answer)
                 logging.info(f"Answer: {answer}")
 
-                # NOTE: Adding to history before and after. This could be seen as redundant.
-                # If we change it, we have to change the tests (no big deal).
-                self.update_history()
-
                 return answer
             else:
                 # this is a streamed response and needs to be handled differently.
@@ -287,7 +285,7 @@ class BaseLlm(JSONSerializable):
                 self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config)
 
     @staticmethod
-    def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[BaseMessage]:
+    def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[LCBaseMessage]:
         """
         Construct a list of langchain messages
 

+ 0 - 0
embedchain/memory/__init__.py


+ 112 - 0
embedchain/memory/base.py

@@ -0,0 +1,112 @@
+import json
+import logging
+import sqlite3
+import uuid
+from typing import Any, Dict, List, Optional
+
+from embedchain.constants import SQLITE_PATH
+from embedchain.memory.message import ChatMessage
+from embedchain.memory.utils import merge_metadata_dict
+
+# DDL for the chat history table: each row holds one question and its answer,
+# metadata is stored as TEXT, and (id, app_id) is the composite primary key.
+CHAT_MESSAGE_CREATE_TABLE_QUERY = """
+            CREATE TABLE IF NOT EXISTS chat_history (
+                app_id TEXT,
+                id TEXT,
+                question TEXT,
+                answer TEXT,
+                metadata TEXT,
+                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+                PRIMARY KEY (id, app_id)
+            )
+            """
+
+
+class ECChatMemory:
+    """
+    SQLite-backed chat history store: one row per question/answer exchange.
+
+    One connection and cursor to ``SQLITE_PATH`` are opened on construction and
+    reused by every method. Call ``close_connection`` when finished.
+    """
+
+    def __init__(self) -> None:
+        # Plain assignment rather than ``with sqlite3.connect(...) as ...``:
+        # sqlite3's connection context manager only commits or rolls back, it
+        # never closes, so the ``with`` form suggested a scope that did not exist.
+        self.connection = sqlite3.connect(SQLITE_PATH)
+        self.cursor = self.connection.cursor()
+
+        # Create the table on first use; a no-op when it already exists.
+        self.cursor.execute(CHAT_MESSAGE_CREATE_TABLE_QUERY)
+        self.connection.commit()
+
+    def add(self, app_id, chat_message: ChatMessage) -> Optional[str]:
+        """
+        Store one exchange for ``app_id``.
+
+        Args:
+            app_id: id the row is stored under
+            chat_message (ChatMessage): exchange whose human and AI
+            messages are both set
+
+        Returns:
+            str: generated id of the new row
+        """
+        memory_id = str(uuid.uuid4())
+        metadata_dict = merge_metadata_dict(chat_message.human_message.metadata, chat_message.ai_message.metadata)
+        # Missing or empty metadata is stored as an empty JSON object.
+        metadata = self._serialize_json(metadata_dict) if metadata_dict else "{}"
+        ADD_CHAT_MESSAGE_QUERY = """
+            INSERT INTO chat_history (app_id, id, question, answer, metadata)
+            VALUES (?, ?, ?, ?, ?)
+        """
+        self.cursor.execute(
+            ADD_CHAT_MESSAGE_QUERY,
+            (
+                app_id,
+                memory_id,
+                chat_message.human_message.content,
+                chat_message.ai_message.content,
+                metadata,
+            ),
+        )
+        self.connection.commit()
+        logging.info(f"Added chat memory to db with id: {memory_id}")
+        return memory_id
+
+    def delete_chat_history(self, app_id: str):
+        """Delete every stored exchange for ``app_id`` and commit."""
+        DELETE_CHAT_HISTORY_QUERY = """
+            DELETE FROM chat_history WHERE app_id=?
+        """
+        self.cursor.execute(
+            DELETE_CHAT_HISTORY_QUERY,
+            (app_id,),
+        )
+        self.connection.commit()
+
+    def get_recent_memories(self, app_id, num_rounds=10) -> List[ChatMessage]:
+        """
+        Get the most recent num_rounds rounds of conversations
+        between human and AI, for a given app_id.
+
+        Returns:
+            List[ChatMessage]: at most ``num_rounds`` exchanges, newest first
+        """
+
+        # created_at only has one-second resolution, so rowid breaks ties and
+        # keeps the order deterministic for exchanges stored in the same second.
+        # Columns are named explicitly so a schema change cannot silently
+        # shift the positional unpacking below.
+        QUERY = """
+            SELECT question, answer, metadata FROM chat_history
+            WHERE app_id=?
+            ORDER BY created_at DESC, rowid DESC
+            LIMIT ?
+        """
+        self.cursor.execute(
+            QUERY,
+            (app_id, num_rounds),
+        )
+
+        history = []
+        for question, answer, raw_metadata in self.cursor.fetchall():
+            metadata = self._deserialize_json(metadata=raw_metadata)
+            memory = ChatMessage()
+            memory.add_user_message(question, metadata=metadata)
+            memory.add_ai_message(answer, metadata=metadata)
+            history.append(memory)
+        return history
+
+    def _serialize_json(self, metadata: Dict[str, Any]):
+        """Encode a metadata dict as a JSON string for the TEXT column."""
+        return json.dumps(metadata)
+
+    def _deserialize_json(self, metadata: str):
+        """Decode the JSON string read from the metadata column."""
+        return json.loads(metadata)
+
+    def close_connection(self):
+        """Close the underlying SQLite connection."""
+        self.connection.close()
+
+    def count_history_messages(self, app_id: str):
+        """Return the number of stored exchanges (rows) for ``app_id``."""
+        QUERY = """
+        SELECT COUNT(*) FROM chat_history
+        WHERE app_id=?
+        """
+        self.cursor.execute(
+            QUERY,
+            (app_id,),
+        )
+        count = self.cursor.fetchone()[0]
+        return count

+ 72 - 0
embedchain/memory/message.py

@@ -0,0 +1,72 @@
+import logging
+from typing import Any, Dict, Optional
+
+from embedchain.helper.json_serializable import JSONSerializable
+
+
+class BaseMessage(JSONSerializable):
+    """
+    The base abstract message class.
+
+    Messages are the inputs and outputs of Models.
+    """
+
+    # The string content of the message.
+    content: str
+
+    # The creator of the message. AI, Human, Bot etc.
+    # (Previously annotated as ``by``, a name that is never set or read.)
+    creator: str
+
+    # Any additional info; None when no metadata was supplied.
+    metadata: Optional[Dict[str, Any]]
+
+    def __init__(self, content: str, creator: str, metadata: Optional[Dict[str, Any]] = None) -> None:
+        """
+        Args:
+            content (str): text of the message
+            creator (str): who produced it, e.g. "human" or "ai"
+            metadata (Optional[Dict[str, Any]]): additional info, stored as given
+        """
+        super().__init__()
+        self.content = content
+        self.creator = creator
+        self.metadata = metadata
+
+    @property
+    def type(self) -> Optional[str]:
+        """Type of the Message, used for serialization. The base class defines none."""
+        # Explicit so the annotation matches what has always been returned.
+        return None
+
+    @classmethod
+    def is_lc_serializable(cls) -> bool:
+        """Return whether this class is serializable."""
+        return True
+
+    def __str__(self) -> str:
+        """Render the message as ``"<creator>: <content>"``."""
+        return f"{self.creator}: {self.content}"
+
+
+class ChatMessage(JSONSerializable):
+    """
+    The base abstract chat message class.
+
+    Chat messages are the pair of (question, answer) conversation
+    between human and model.
+    """
+
+    human_message: Optional[BaseMessage] = None
+    ai_message: Optional[BaseMessage] = None
+
+    def add_user_message(self, message: str, metadata: Optional[dict] = None):
+        """
+        Set the human half of the exchange, replacing any existing one.
+
+        Args:
+            message (str): text of the user's message
+            metadata (Optional[dict]): additional info stored on the message
+        """
+        if self.human_message:
+            # Adjacent literals instead of a backslash continuation inside the
+            # string, which embedded the next line's indentation in the log text.
+            logging.info(
+                "Human message already exists in the chat message, "
+                "overwriting it with new message."
+            )
+
+        self.human_message = BaseMessage(content=message, creator="human", metadata=metadata)
+
+    def add_ai_message(self, message: str, metadata: Optional[dict] = None):
+        """
+        Set the AI half of the exchange, replacing any existing one.
+
+        Args:
+            message (str): text of the model's answer
+            metadata (Optional[dict]): additional info stored on the message
+        """
+        if self.ai_message:
+            logging.info(
+                "AI message already exists in the chat message, "
+                "overwriting it with new message."
+            )
+
+        self.ai_message = BaseMessage(content=message, creator="ai", metadata=metadata)
+
+    def __str__(self) -> str:
+        """Render the exchange as ``"<human message> | <ai message>"``."""
+        return f"{self.human_message} | {self.ai_message}"

+ 35 - 0
embedchain/memory/utils.py

@@ -0,0 +1,35 @@
+from typing import Any, Dict, Optional
+
+
+def merge_metadata_dict(left: Optional[Dict[str, Any]], right: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
+    """
+    Merge the metadatas of two BaseMessage types.
+
+    Keys present on one side only are copied over. For a key present on both
+    sides: equal values are kept once, strings are concatenated, and dicts
+    are merged recursively.
+
+    Args:
+        left (Dict[str, Any]): metadata of human message
+        right (Dict[str, Any]): metadata of ai message
+
+    Returns:
+        Dict[str, Any]: combined metadata dict with dedup
+        to be saved in db. None when both inputs are empty or None; when only
+        one side is non-empty that same object is returned.
+
+    Raises:
+        ValueError: a shared key holds values of different types, or unequal
+        values that are neither str nor dict.
+    """
+    if not left and not right:
+        return None
+    elif not left:
+        return right
+    elif not right:
+        return left
+
+    merged = left.copy()
+    for k, v in right.items():
+        if k not in merged:
+            merged[k] = v
+        elif type(merged[k]) is not type(v):
+            raise ValueError(f'metadata["{k}"] already exists in this message, but with a different type.')
+        elif merged[k] == v:
+            # Dedup: the same metadata is often attached to both messages, so
+            # identical values must be kept once instead of doubled or rejected.
+            continue
+        elif isinstance(merged[k], str):
+            merged[k] += v
+        elif isinstance(merged[k], dict):
+            merged[k] = merge_metadata_dict(merged[k], v)
+        else:
+            raise ValueError(f"Metadata key {k} already exists in this message with a different value.")
+    return merged

+ 2 - 3
embedchain/pipeline.py

@@ -10,7 +10,8 @@ import yaml
 
 from embedchain import Client
 from embedchain.config import ChunkerConfig, PipelineConfig
-from embedchain.embedchain import CONFIG_DIR, EmbedChain
+from embedchain.constants import SQLITE_PATH
+from embedchain.embedchain import EmbedChain
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.embedder.openai import OpenAIEmbedder
 from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
@@ -22,8 +23,6 @@ from embedchain.utils import validate_yaml_config
 from embedchain.vectordb.base import BaseVectorDB
 from embedchain.vectordb.chroma import ChromaDB
 
-SQLITE_PATH = os.path.join(CONFIG_DIR, "embedchain.db")
-
 
 @register_deserializable
 class Pipeline(EmbedChain):

+ 2 - 1
embedchain/utils.py

@@ -138,7 +138,8 @@ def detect_datatype(source: Any) -> DataType:
     formatted_source = format_source(str(source), 30)
 
     if url:
-        from langchain.document_loaders.youtube import ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
+        from langchain.document_loaders.youtube import \
+            ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
 
         if url.netloc in YOUTUBE_ALLOWED_NETLOCS:
             logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")

+ 10 - 0
tests/embedchain/test_embedchain.py

@@ -7,6 +7,7 @@ from embedchain import App
 from embedchain.config import AppConfig, ChromaDbConfig
 from embedchain.embedchain import EmbedChain
 from embedchain.llm.base import BaseLlm
+from embedchain.memory.base import ECChatMemory
 
 os.environ["OPENAI_API_KEY"] = "test-api-key"
 
@@ -25,6 +26,11 @@ def test_whole_app(app_instance, mocker):
     mocker.patch.object(BaseLlm, "get_answer_from_llm", return_value=knowledge)
     mocker.patch.object(BaseLlm, "get_llm_model_answer", return_value=knowledge)
     mocker.patch.object(BaseLlm, "generate_prompt")
+    mocker.patch.object(
+        BaseLlm,
+        "add_history",
+    )
+    mocker.patch.object(ECChatMemory, "delete_chat_history", autospec=True)
 
     app_instance.add(knowledge, data_type="text")
     app_instance.query("What text did I give you?")
@@ -41,6 +47,10 @@ def test_add_after_reset(app_instance, mocker):
     chroma_config = {"allow_reset": True}
 
     app_instance = App(config=config, db_config=ChromaDbConfig(**chroma_config))
+
+    # mock delete chat history
+    mocker.patch.object(ECChatMemory, "delete_chat_history", autospec=True)
+
     app_instance.reset()
 
     app_instance.db.client.heartbeat()

+ 26 - 18
tests/llm/test_chat.py

@@ -5,6 +5,8 @@ from unittest.mock import MagicMock, patch
 from embedchain import App
 from embedchain.config import AppConfig, BaseLlmConfig
 from embedchain.llm.base import BaseLlm
+from embedchain.memory.base import ECChatMemory
+from embedchain.memory.message import ChatMessage
 
 
 class TestApp(unittest.TestCase):
@@ -31,14 +33,14 @@ class TestApp(unittest.TestCase):
         """
         config = AppConfig(collect_metrics=False)
         app = App(config=config)
-        first_answer = app.chat("Test query 1")
-        self.assertEqual(first_answer, "Test answer")
-        self.assertEqual(len(app.llm.memory.chat_memory.messages), 2)
-        self.assertEqual(len(app.llm.history.splitlines()), 2)
-        second_answer = app.chat("Test query 2")
-        self.assertEqual(second_answer, "Test answer")
-        self.assertEqual(len(app.llm.memory.chat_memory.messages), 4)
-        self.assertEqual(len(app.llm.history.splitlines()), 4)
+        with patch.object(BaseLlm, "add_history") as mock_history:
+            first_answer = app.chat("Test query 1")
+            self.assertEqual(first_answer, "Test answer")
+            mock_history.assert_called_with(app.config.id, "Test query 1", "Test answer")
+
+            second_answer = app.chat("Test query 2")
+            self.assertEqual(second_answer, "Test answer")
+            mock_history.assert_called_with(app.config.id, "Test query 2", "Test answer")
 
     @patch.object(App, "retrieve_from_database", return_value=["Test context"])
     @patch.object(BaseLlm, "get_answer_from_llm", return_value="Test answer")
@@ -49,16 +51,22 @@ class TestApp(unittest.TestCase):
 
         Also tests that a dry run does not change the history
         """
-        config = AppConfig(collect_metrics=False)
-        app = App(config=config)
-        first_answer = app.chat("Test query 1")
-        self.assertEqual(first_answer, "Test answer")
-        self.assertEqual(len(app.llm.history.splitlines()), 2)
-        history = app.llm.history
-        dry_run = app.chat("Test query 2", dry_run=True)
-        self.assertIn("History:", dry_run)
-        self.assertEqual(history, app.llm.history)
-        self.assertEqual(len(app.llm.history.splitlines()), 2)
+        with patch.object(ECChatMemory, "get_recent_memories") as mock_memory:
+            mock_message = ChatMessage()
+            mock_message.add_user_message("Test query 1")
+            mock_message.add_ai_message("Test answer")
+            mock_memory.return_value = [mock_message]
+
+            config = AppConfig(collect_metrics=False)
+            app = App(config=config)
+            first_answer = app.chat("Test query 1")
+            self.assertEqual(first_answer, "Test answer")
+            self.assertEqual(len(app.llm.history), 1)
+            history = app.llm.history
+            dry_run = app.chat("Test query 2", dry_run=True)
+            self.assertIn("History:", dry_run)
+            self.assertEqual(history, app.llm.history)
+            self.assertEqual(len(app.llm.history), 1)
 
     @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
     def test_chat_with_where_in_params(self):

+ 67 - 0
tests/memory/test_chat_memory.py

@@ -0,0 +1,67 @@
+import pytest
+
+from embedchain.memory.base import ECChatMemory
+from embedchain.memory.message import ChatMessage
+
+
+# Fixture for creating an instance of ECChatMemory
+@pytest.fixture
+def chat_memory_instance():
+    """Yield an ECChatMemory and close its SQLite connection once the test finishes."""
+    memory = ECChatMemory()
+    yield memory
+    # Teardown: without this the connection opened in __init__ is never closed.
+    memory.close_connection()
+
+
+def test_add_chat_memory(chat_memory_instance):
+    """Adding one exchange leaves exactly one stored row for the app."""
+    app_id = "test_app"
+
+    exchange = ChatMessage()
+    exchange.add_user_message("Hello, how are you?")
+    exchange.add_ai_message("I'm fine, thank you!")
+    chat_memory_instance.add(app_id, exchange)
+
+    stored = chat_memory_instance.count_history_messages(app_id)
+    assert stored == 1
+    # Remove the row so it does not leak into other tests using the same app id.
+    chat_memory_instance.delete_chat_history(app_id)
+
+
+def test_get_recent_memories(chat_memory_instance):
+    """With six stored exchanges, asking for five rounds returns exactly five."""
+    app_id = "test_app"
+
+    for i in range(1, 7):
+        human_message = f"Question {i}"
+        ai_message = f"Answer {i}"
+
+        chat_message = ChatMessage()
+        chat_message.add_user_message(human_message)
+        chat_message.add_ai_message(ai_message)
+
+        chat_memory_instance.add(app_id, chat_message)
+
+    recent_memories = chat_memory_instance.get_recent_memories(app_id, num_rounds=5)
+    # The store is a shared on-disk database: clean up before asserting so the
+    # six rows never leak into other tests, even when the assertion fails.
+    chat_memory_instance.delete_chat_history(app_id)
+
+    assert len(recent_memories) == 5
+
+
+def test_delete_chat_history(chat_memory_instance):
+    """Deleting an app's history leaves no stored rows for that app."""
+    app_id = "test_app"
+
+    # Store five exchanges: "Question 1".."Question 5" with matching answers.
+    for round_no in range(1, 6):
+        exchange = ChatMessage()
+        exchange.add_user_message(f"Question {round_no}")
+        exchange.add_ai_message(f"Answer {round_no}")
+        chat_memory_instance.add(app_id, exchange)
+
+    chat_memory_instance.delete_chat_history(app_id)
+
+    remaining = chat_memory_instance.count_history_messages(app_id)
+    assert remaining == 0
+
+
+@pytest.fixture
+def close_connection(chat_memory_instance):
+    """Teardown-only fixture: after the test, close the memory's SQLite connection."""
+    # NOTE(review): no test in this module requests this fixture and it is not
+    # autouse, so as written it does not appear to run - confirm intent.
+    yield
+    chat_memory_instance.close_connection()

+ 37 - 0
tests/memory/test_memory_messages.py

@@ -0,0 +1,37 @@
+from embedchain.memory.message import BaseMessage, ChatMessage
+
+
+def test_ec_base_message():
+    """BaseMessage keeps its constructor arguments and renders as 'creator: content'."""
+    text, author, extra = "Hello, how are you?", "human", {"key": "value"}
+
+    message = BaseMessage(content=text, creator=author, metadata=extra)
+
+    # Stored attributes mirror the constructor arguments.
+    assert message.content == text
+    assert message.creator == author
+    assert message.metadata == extra
+    # The base class defines no type and reports itself as serializable.
+    assert message.type is None
+    assert message.is_lc_serializable() is True
+    assert str(message) == f"{author}: {text}"
+
+
+def test_ec_base_chat_message():
+    """ChatMessage stores both halves of an exchange and renders them joined by ' | '."""
+    question, reply = "Hello, how are you?", "I'm fine, thank you!"
+    question_meta = {"user": "John"}
+    reply_meta = {"response_time": 0.5}
+
+    chat_message = ChatMessage()
+    chat_message.add_user_message(question, metadata=question_meta)
+    chat_message.add_ai_message(reply, metadata=reply_meta)
+
+    human = chat_message.human_message
+    assert human.content == question
+    assert human.creator == "human"
+    assert human.metadata == question_meta
+
+    ai = chat_message.ai_message
+    assert ai.content == reply
+    assert ai.creator == "ai"
+    assert ai.metadata == reply_meta
+
+    assert str(chat_message) == f"human: {question} | ai: {reply}"

+ 1 - 0
tests/test_utils.py

@@ -1,4 +1,5 @@
 import yaml
+
 from embedchain.utils import validate_yaml_config
 
 CONFIG_YAMLS = [