浏览代码

[Improvement] Use SQLite for chat memory (#910)

Co-authored-by: Deven Patel <deven298@yahoo.com>
Deven Patel 1 年之前
父节点
当前提交
654fd8d74c

+ 1 - 1
embedchain/client.py

@@ -5,7 +5,7 @@ import uuid
 
 
 import requests
 import requests
 
 
-from embedchain.embedchain import CONFIG_DIR, CONFIG_FILE
+from embedchain.constants import CONFIG_DIR, CONFIG_FILE
 
 
 
 
 class Client:
 class Client:

+ 8 - 0
embedchain/constants.py

@@ -0,0 +1,8 @@
+import os
+from pathlib import Path
+
+ABS_PATH = os.getcwd()
+HOME_DIR = str(Path.home())
+CONFIG_DIR = os.path.join(HOME_DIR, ".embedchain")
+CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
+SQLITE_PATH = os.path.join(CONFIG_DIR, "embedchain.db")

+ 8 - 8
embedchain/embedchain.py

@@ -1,9 +1,7 @@
 import hashlib
 import hashlib
 import json
 import json
 import logging
 import logging
-import os
 import sqlite3
 import sqlite3
-from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Union
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 
 from dotenv import load_dotenv
 from dotenv import load_dotenv
@@ -12,6 +10,7 @@ from langchain.docstore.document import Document
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
 from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
 from embedchain.config.apps.base_app_config import BaseAppConfig
 from embedchain.config.apps.base_app_config import BaseAppConfig
+from embedchain.constants import SQLITE_PATH
 from embedchain.data_formatter import DataFormatter
 from embedchain.data_formatter import DataFormatter
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.helper.json_serializable import JSONSerializable
 from embedchain.helper.json_serializable import JSONSerializable
@@ -25,12 +24,6 @@ from embedchain.vectordb.base import BaseVectorDB
 
 
 load_dotenv()
 load_dotenv()
 
 
-ABS_PATH = os.getcwd()
-HOME_DIR = str(Path.home())
-CONFIG_DIR = os.path.join(HOME_DIR, ".embedchain")
-CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
-SQLITE_PATH = os.path.join(CONFIG_DIR, "embedchain.db")
-
 
 
 class EmbedChain(JSONSerializable):
 class EmbedChain(JSONSerializable):
     def __init__(
     def __init__(
@@ -602,6 +595,9 @@ class EmbedChain(JSONSerializable):
             input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
             input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
         )
         )
 
 
+        # add conversation in memory
+        self.llm.add_history(self.config.id, input_query, answer)
+
         # Send anonymous telemetry
         # Send anonymous telemetry
         self.telemetry.capture(event_name="chat", properties=self._telemetry_props)
         self.telemetry.capture(event_name="chat", properties=self._telemetry_props)
 
 
@@ -645,5 +641,9 @@ class EmbedChain(JSONSerializable):
         self.db.reset()
         self.db.reset()
         self.cursor.execute("DELETE FROM data_sources WHERE pipeline_id = ?", (self.config.id,))
         self.cursor.execute("DELETE FROM data_sources WHERE pipeline_id = ?", (self.config.id,))
         self.connection.commit()
         self.connection.commit()
+        self.clear_history()
         # Send anonymous telemetry
         # Send anonymous telemetry
         self.telemetry.capture(event_name="reset", properties=self._telemetry_props)
         self.telemetry.capture(event_name="reset", properties=self._telemetry_props)
+
+    def clear_history(self):
+        self.llm.memory.delete_chat_history(app_id=self.config.id)

+ 15 - 17
embedchain/llm/base.py

@@ -1,14 +1,15 @@
 import logging
 import logging
 from typing import Any, Dict, Generator, List, Optional
 from typing import Any, Dict, Generator, List, Optional
 
 
-from langchain.memory import ConversationBufferMemory
-from langchain.schema import BaseMessage
+from langchain.schema import BaseMessage as LCBaseMessage
 
 
 from embedchain.config import BaseLlmConfig
 from embedchain.config import BaseLlmConfig
 from embedchain.config.llm.base import (DEFAULT_PROMPT,
 from embedchain.config.llm.base import (DEFAULT_PROMPT,
                                         DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
                                         DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
                                         DOCS_SITE_PROMPT_TEMPLATE)
                                         DOCS_SITE_PROMPT_TEMPLATE)
 from embedchain.helper.json_serializable import JSONSerializable
 from embedchain.helper.json_serializable import JSONSerializable
+from embedchain.memory.base import ECChatMemory
+from embedchain.memory.message import ChatMessage
 
 
 
 
 class BaseLlm(JSONSerializable):
 class BaseLlm(JSONSerializable):
@@ -23,7 +24,7 @@ class BaseLlm(JSONSerializable):
         else:
         else:
             self.config = config
             self.config = config
 
 
-        self.memory = ConversationBufferMemory()
+        self.memory = ECChatMemory()
         self.is_docs_site_instance = False
         self.is_docs_site_instance = False
         self.online = False
         self.online = False
         self.history: Any = None
         self.history: Any = None
@@ -44,11 +45,18 @@ class BaseLlm(JSONSerializable):
         """
         """
         self.history = history
         self.history = history
 
 
-    def update_history(self):
+    def update_history(self, app_id: str):
         """Update class history attribute with history in memory (for chat method)"""
         """Update class history attribute with history in memory (for chat method)"""
-        chat_history = self.memory.load_memory_variables({})["history"]
+        chat_history = self.memory.get_recent_memories(app_id=app_id, num_rounds=10)
         if chat_history:
         if chat_history:
-            self.set_history(chat_history)
+            self.set_history([str(history) for history in chat_history])
+
+    def add_history(self, app_id: str, question: str, answer: str, metadata: Optional[Dict[str, Any]] = None):
+        chat_message = ChatMessage()
+        chat_message.add_user_message(question, metadata=metadata)
+        chat_message.add_ai_message(answer, metadata=metadata)
+        self.memory.add(app_id=app_id, chat_message=chat_message)
+        self.update_history(app_id=app_id)
 
 
     def generate_prompt(self, input_query: str, contexts: List[str], **kwargs: Dict[str, Any]) -> str:
     def generate_prompt(self, input_query: str, contexts: List[str], **kwargs: Dict[str, Any]) -> str:
         """
         """
@@ -165,7 +173,6 @@ class BaseLlm(JSONSerializable):
         for chunk in answer:
         for chunk in answer:
             streamed_answer = streamed_answer + chunk
             streamed_answer = streamed_answer + chunk
             yield chunk
             yield chunk
-        self.memory.chat_memory.add_ai_message(streamed_answer)
         logging.info(f"Answer: {streamed_answer}")
         logging.info(f"Answer: {streamed_answer}")
 
 
     def query(self, input_query: str, contexts: List[str], config: BaseLlmConfig = None, dry_run=False):
     def query(self, input_query: str, contexts: List[str], config: BaseLlmConfig = None, dry_run=False):
@@ -257,8 +264,6 @@ class BaseLlm(JSONSerializable):
             if self.online:
             if self.online:
                 k["web_search_result"] = self.access_search_and_get_results(input_query)
                 k["web_search_result"] = self.access_search_and_get_results(input_query)
 
 
-            self.update_history()
-
             prompt = self.generate_prompt(input_query, contexts, **k)
             prompt = self.generate_prompt(input_query, contexts, **k)
             logging.info(f"Prompt: {prompt}")
             logging.info(f"Prompt: {prompt}")
 
 
@@ -267,16 +272,9 @@ class BaseLlm(JSONSerializable):
 
 
             answer = self.get_answer_from_llm(prompt)
             answer = self.get_answer_from_llm(prompt)
 
 
-            self.memory.chat_memory.add_user_message(input_query)
-
             if isinstance(answer, str):
             if isinstance(answer, str):
-                self.memory.chat_memory.add_ai_message(answer)
                 logging.info(f"Answer: {answer}")
                 logging.info(f"Answer: {answer}")
 
 
-                # NOTE: Adding to history before and after. This could be seen as redundant.
-                # If we change it, we have to change the tests (no big deal).
-                self.update_history()
-
                 return answer
                 return answer
             else:
             else:
                 # this is a streamed response and needs to be handled differently.
                 # this is a streamed response and needs to be handled differently.
@@ -287,7 +285,7 @@ class BaseLlm(JSONSerializable):
                 self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config)
                 self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config)
 
 
     @staticmethod
     @staticmethod
-    def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[BaseMessage]:
+    def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[LCBaseMessage]:
         """
         """
         Construct a list of langchain messages
         Construct a list of langchain messages
 
 

+ 0 - 0
embedchain/memory/__init__.py


+ 112 - 0
embedchain/memory/base.py

@@ -0,0 +1,112 @@
+import json
+import logging
+import sqlite3
+import uuid
+from typing import Any, Dict, List, Optional
+
+from embedchain.constants import SQLITE_PATH
+from embedchain.memory.message import ChatMessage
+from embedchain.memory.utils import merge_metadata_dict
+
+CHAT_MESSAGE_CREATE_TABLE_QUERY = """
+            CREATE TABLE IF NOT EXISTS chat_history (
+                app_id TEXT,
+                id TEXT,
+                question TEXT,
+                answer TEXT,
+                metadata TEXT,
+                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+                PRIMARY KEY (id, app_id)
+            )
+            """
+
+
+class ECChatMemory:
+    def __init__(self) -> None:
+        with sqlite3.connect(SQLITE_PATH) as self.connection:
+            self.cursor = self.connection.cursor()
+
+            self.cursor.execute(CHAT_MESSAGE_CREATE_TABLE_QUERY)
+            self.connection.commit()
+
+    def add(self, app_id, chat_message: ChatMessage) -> Optional[str]:
+        memory_id = str(uuid.uuid4())
+        metadata_dict = merge_metadata_dict(chat_message.human_message.metadata, chat_message.ai_message.metadata)
+        if metadata_dict:
+            metadata = self._serialize_json(metadata_dict)
+        ADD_CHAT_MESSAGE_QUERY = """
+            INSERT INTO chat_history (app_id, id, question, answer, metadata)
+            VALUES (?, ?, ?, ?, ?)
+        """
+        self.cursor.execute(
+            ADD_CHAT_MESSAGE_QUERY,
+            (
+                app_id,
+                memory_id,
+                chat_message.human_message.content,
+                chat_message.ai_message.content,
+                metadata if metadata_dict else "{}",
+            ),
+        )
+        self.connection.commit()
+        logging.info(f"Added chat memory to db with id: {memory_id}")
+        return memory_id
+
+    def delete_chat_history(self, app_id: str):
+        DELETE_CHAT_HISTORY_QUERY = """
+            DELETE FROM chat_history WHERE app_id=?
+        """
+        self.cursor.execute(
+            DELETE_CHAT_HISTORY_QUERY,
+            (app_id,),
+        )
+        self.connection.commit()
+
+    def get_recent_memories(self, app_id, num_rounds=10) -> List[ChatMessage]:
+        """
+        Get the most recent num_rounds rounds of conversations
+        between human and AI, for a given app_id.
+        """
+
+        QUERY = """
+            SELECT * FROM chat_history
+            WHERE app_id=?
+            ORDER BY created_at DESC
+            LIMIT ?
+        """
+        self.cursor.execute(
+            QUERY,
+            (app_id, num_rounds),
+        )
+
+        results = self.cursor.fetchall()
+        history = []
+        for result in results:
+            app_id, id, question, answer, metadata, timestamp = result
+            metadata = self._deserialize_json(metadata=metadata)
+            memory = ChatMessage()
+            memory.add_user_message(question, metadata=metadata)
+            memory.add_ai_message(answer, metadata=metadata)
+            history.append(memory)
+        return history
+
+    def _serialize_json(self, metadata: Dict[str, Any]):
+        return json.dumps(metadata)
+
+    def _deserialize_json(self, metadata: str):
+        return json.loads(metadata)
+
+    def close_connection(self):
+        self.connection.close()
+
+    def count_history_messages(self, app_id: str):
+        QUERY = """
+        SELECT COUNT(*) FROM chat_history
+        WHERE app_id=?
+        """
+        self.cursor.execute(
+            QUERY,
+            (app_id,),
+        )
+        count = self.cursor.fetchone()[0]
+        return count

+ 72 - 0
embedchain/memory/message.py

@@ -0,0 +1,72 @@
+import logging
+from typing import Any, Dict, Optional
+
+from embedchain.helper.json_serializable import JSONSerializable
+
+
+class BaseMessage(JSONSerializable):
+    """
+    The base abstract message class.
+
+    Messages are the inputs and outputs of Models.
+    """
+
+    # The string content of the message.
+    content: str
+
+    # The creator of the message. AI, Human, Bot etc.
+    by: str
+
+    # Any additional info.
+    metadata: Dict[str, Any]
+
+    def __init__(self, content: str, creator: str, metadata: Optional[Dict[str, Any]] = None) -> None:
+        super().__init__()
+        self.content = content
+        self.creator = creator
+        self.metadata = metadata
+
+    @property
+    def type(self) -> str:
+        """Type of the Message, used for serialization."""
+
+    @classmethod
+    def is_lc_serializable(cls) -> bool:
+        """Return whether this class is serializable."""
+        return True
+
+    def __str__(self) -> str:
+        return f"{self.creator}: {self.content}"
+
+
+class ChatMessage(JSONSerializable):
+    """
+    The base abstract chat message class.
+
+    Chat messages are the pair of (question, answer) conversation
+    between human and model.
+    """
+
+    human_message: Optional[BaseMessage] = None
+    ai_message: Optional[BaseMessage] = None
+
+    def add_user_message(self, message: str, metadata: Optional[dict] = None):
+        if self.human_message:
+            logging.info(
+                "Human message already exists in the chat message,\
+                overwritting it with new message."
+            )
+
+        self.human_message = BaseMessage(content=message, creator="human", metadata=metadata)
+
+    def add_ai_message(self, message: str, metadata: Optional[dict] = None):
+        if self.ai_message:
+            logging.info(
+                "AI message already exists in the chat message,\
+                overwritting it with new message."
+            )
+
+        self.ai_message = BaseMessage(content=message, creator="ai", metadata=metadata)
+
+    def __str__(self) -> str:
+        return f"{self.human_message} | {self.ai_message}"

+ 35 - 0
embedchain/memory/utils.py

@@ -0,0 +1,35 @@
+from typing import Any, Dict, Optional
+
+
+def merge_metadata_dict(left: Optional[Dict[str, Any]], right: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
+    """
+    Merge the metadatas of two BaseMessage types.
+
+    Args:
+        left (Dict[str, Any]): metadata of human message
+        right (Dict[str, Any]): metadata of ai message
+
+    Returns:
+        Dict[str, Any]: combined metadata dict with dedup
+        to be saved in db.
+    """
+    if not left and not right:
+        return None
+    elif not left:
+        return right
+    elif not right:
+        return left
+
+    merged = left.copy()
+    for k, v in right.items():
+        if k not in merged:
+            merged[k] = v
+        elif type(merged[k]) != type(v):
+            raise ValueError(f'additional_kwargs["{k}"] already exists in this message,' " but with a different type.")
+        elif isinstance(merged[k], str):
+            merged[k] += v
+        elif isinstance(merged[k], dict):
+            merged[k] = merge_metadata_dict(merged[k], v)
+        else:
+            raise ValueError(f"Additional kwargs key {k} already exists in this message.")
+    return merged

+ 2 - 3
embedchain/pipeline.py

@@ -10,7 +10,8 @@ import yaml
 
 
 from embedchain import Client
 from embedchain import Client
 from embedchain.config import ChunkerConfig, PipelineConfig
 from embedchain.config import ChunkerConfig, PipelineConfig
-from embedchain.embedchain import CONFIG_DIR, EmbedChain
+from embedchain.constants import SQLITE_PATH
+from embedchain.embedchain import EmbedChain
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.embedder.openai import OpenAIEmbedder
 from embedchain.embedder.openai import OpenAIEmbedder
 from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
 from embedchain.factory import EmbedderFactory, LlmFactory, VectorDBFactory
@@ -22,8 +23,6 @@ from embedchain.utils import validate_yaml_config
 from embedchain.vectordb.base import BaseVectorDB
 from embedchain.vectordb.base import BaseVectorDB
 from embedchain.vectordb.chroma import ChromaDB
 from embedchain.vectordb.chroma import ChromaDB
 
 
-SQLITE_PATH = os.path.join(CONFIG_DIR, "embedchain.db")
-
 
 
 @register_deserializable
 @register_deserializable
 class Pipeline(EmbedChain):
 class Pipeline(EmbedChain):

+ 2 - 1
embedchain/utils.py

@@ -138,7 +138,8 @@ def detect_datatype(source: Any) -> DataType:
     formatted_source = format_source(str(source), 30)
     formatted_source = format_source(str(source), 30)
 
 
     if url:
     if url:
-        from langchain.document_loaders.youtube import ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
+        from langchain.document_loaders.youtube import \
+            ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
 
 
         if url.netloc in YOUTUBE_ALLOWED_NETLOCS:
         if url.netloc in YOUTUBE_ALLOWED_NETLOCS:
             logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")
             logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")

+ 10 - 0
tests/embedchain/test_embedchain.py

@@ -7,6 +7,7 @@ from embedchain import App
 from embedchain.config import AppConfig, ChromaDbConfig
 from embedchain.config import AppConfig, ChromaDbConfig
 from embedchain.embedchain import EmbedChain
 from embedchain.embedchain import EmbedChain
 from embedchain.llm.base import BaseLlm
 from embedchain.llm.base import BaseLlm
+from embedchain.memory.base import ECChatMemory
 
 
 os.environ["OPENAI_API_KEY"] = "test-api-key"
 os.environ["OPENAI_API_KEY"] = "test-api-key"
 
 
@@ -25,6 +26,11 @@ def test_whole_app(app_instance, mocker):
     mocker.patch.object(BaseLlm, "get_answer_from_llm", return_value=knowledge)
     mocker.patch.object(BaseLlm, "get_answer_from_llm", return_value=knowledge)
     mocker.patch.object(BaseLlm, "get_llm_model_answer", return_value=knowledge)
     mocker.patch.object(BaseLlm, "get_llm_model_answer", return_value=knowledge)
     mocker.patch.object(BaseLlm, "generate_prompt")
     mocker.patch.object(BaseLlm, "generate_prompt")
+    mocker.patch.object(
+        BaseLlm,
+        "add_history",
+    )
+    mocker.patch.object(ECChatMemory, "delete_chat_history", autospec=True)
 
 
     app_instance.add(knowledge, data_type="text")
     app_instance.add(knowledge, data_type="text")
     app_instance.query("What text did I give you?")
     app_instance.query("What text did I give you?")
@@ -41,6 +47,10 @@ def test_add_after_reset(app_instance, mocker):
     chroma_config = {"allow_reset": True}
     chroma_config = {"allow_reset": True}
 
 
     app_instance = App(config=config, db_config=ChromaDbConfig(**chroma_config))
     app_instance = App(config=config, db_config=ChromaDbConfig(**chroma_config))
+
+    # mock delete chat history
+    mocker.patch.object(ECChatMemory, "delete_chat_history", autospec=True)
+
     app_instance.reset()
     app_instance.reset()
 
 
     app_instance.db.client.heartbeat()
     app_instance.db.client.heartbeat()

+ 26 - 18
tests/llm/test_chat.py

@@ -5,6 +5,8 @@ from unittest.mock import MagicMock, patch
 from embedchain import App
 from embedchain import App
 from embedchain.config import AppConfig, BaseLlmConfig
 from embedchain.config import AppConfig, BaseLlmConfig
 from embedchain.llm.base import BaseLlm
 from embedchain.llm.base import BaseLlm
+from embedchain.memory.base import ECChatMemory
+from embedchain.memory.message import ChatMessage
 
 
 
 
 class TestApp(unittest.TestCase):
 class TestApp(unittest.TestCase):
@@ -31,14 +33,14 @@ class TestApp(unittest.TestCase):
         """
         """
         config = AppConfig(collect_metrics=False)
         config = AppConfig(collect_metrics=False)
         app = App(config=config)
         app = App(config=config)
-        first_answer = app.chat("Test query 1")
-        self.assertEqual(first_answer, "Test answer")
-        self.assertEqual(len(app.llm.memory.chat_memory.messages), 2)
-        self.assertEqual(len(app.llm.history.splitlines()), 2)
-        second_answer = app.chat("Test query 2")
-        self.assertEqual(second_answer, "Test answer")
-        self.assertEqual(len(app.llm.memory.chat_memory.messages), 4)
-        self.assertEqual(len(app.llm.history.splitlines()), 4)
+        with patch.object(BaseLlm, "add_history") as mock_history:
+            first_answer = app.chat("Test query 1")
+            self.assertEqual(first_answer, "Test answer")
+            mock_history.assert_called_with(app.config.id, "Test query 1", "Test answer")
+
+            second_answer = app.chat("Test query 2")
+            self.assertEqual(second_answer, "Test answer")
+            mock_history.assert_called_with(app.config.id, "Test query 2", "Test answer")
 
 
     @patch.object(App, "retrieve_from_database", return_value=["Test context"])
     @patch.object(App, "retrieve_from_database", return_value=["Test context"])
     @patch.object(BaseLlm, "get_answer_from_llm", return_value="Test answer")
     @patch.object(BaseLlm, "get_answer_from_llm", return_value="Test answer")
@@ -49,16 +51,22 @@ class TestApp(unittest.TestCase):
 
 
         Also tests that a dry run does not change the history
         Also tests that a dry run does not change the history
         """
         """
-        config = AppConfig(collect_metrics=False)
-        app = App(config=config)
-        first_answer = app.chat("Test query 1")
-        self.assertEqual(first_answer, "Test answer")
-        self.assertEqual(len(app.llm.history.splitlines()), 2)
-        history = app.llm.history
-        dry_run = app.chat("Test query 2", dry_run=True)
-        self.assertIn("History:", dry_run)
-        self.assertEqual(history, app.llm.history)
-        self.assertEqual(len(app.llm.history.splitlines()), 2)
+        with patch.object(ECChatMemory, "get_recent_memories") as mock_memory:
+            mock_message = ChatMessage()
+            mock_message.add_user_message("Test query 1")
+            mock_message.add_ai_message("Test answer")
+            mock_memory.return_value = [mock_message]
+
+            config = AppConfig(collect_metrics=False)
+            app = App(config=config)
+            first_answer = app.chat("Test query 1")
+            self.assertEqual(first_answer, "Test answer")
+            self.assertEqual(len(app.llm.history), 1)
+            history = app.llm.history
+            dry_run = app.chat("Test query 2", dry_run=True)
+            self.assertIn("History:", dry_run)
+            self.assertEqual(history, app.llm.history)
+            self.assertEqual(len(app.llm.history), 1)
 
 
     @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
     @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
     def test_chat_with_where_in_params(self):
     def test_chat_with_where_in_params(self):

+ 67 - 0
tests/memory/test_chat_memory.py

@@ -0,0 +1,67 @@
+import pytest
+
+from embedchain.memory.base import ECChatMemory
+from embedchain.memory.message import ChatMessage
+
+
+# Fixture for creating an instance of ECChatMemory
+@pytest.fixture
+def chat_memory_instance():
+    return ECChatMemory()
+
+
+def test_add_chat_memory(chat_memory_instance):
+    app_id = "test_app"
+    human_message = "Hello, how are you?"
+    ai_message = "I'm fine, thank you!"
+
+    chat_message = ChatMessage()
+    chat_message.add_user_message(human_message)
+    chat_message.add_ai_message(ai_message)
+
+    chat_memory_instance.add(app_id, chat_message)
+
+    assert chat_memory_instance.count_history_messages(app_id) == 1
+    chat_memory_instance.delete_chat_history(app_id)
+
+
+def test_get_recent_memories(chat_memory_instance):
+    app_id = "test_app"
+
+    for i in range(1, 7):
+        human_message = f"Question {i}"
+        ai_message = f"Answer {i}"
+
+        chat_message = ChatMessage()
+        chat_message.add_user_message(human_message)
+        chat_message.add_ai_message(ai_message)
+
+        chat_memory_instance.add(app_id, chat_message)
+
+    recent_memories = chat_memory_instance.get_recent_memories(app_id, num_rounds=5)
+
+    assert len(recent_memories) == 5
+
+
+def test_delete_chat_history(chat_memory_instance):
+    app_id = "test_app"
+
+    for i in range(1, 6):
+        human_message = f"Question {i}"
+        ai_message = f"Answer {i}"
+
+        chat_message = ChatMessage()
+        chat_message.add_user_message(human_message)
+        chat_message.add_ai_message(ai_message)
+
+        chat_memory_instance.add(app_id, chat_message)
+
+    chat_memory_instance.delete_chat_history(app_id)
+
+    assert chat_memory_instance.count_history_messages(app_id) == 0
+
+
+@pytest.fixture
+def close_connection(chat_memory_instance):
+    yield
+    chat_memory_instance.close_connection()

+ 37 - 0
tests/memory/test_memory_messages.py

@@ -0,0 +1,37 @@
+from embedchain.memory.message import BaseMessage, ChatMessage
+
+
+def test_ec_base_message():
+    content = "Hello, how are you?"
+    creator = "human"
+    metadata = {"key": "value"}
+
+    message = BaseMessage(content=content, creator=creator, metadata=metadata)
+
+    assert message.content == content
+    assert message.creator == creator
+    assert message.metadata == metadata
+    assert message.type is None
+    assert message.is_lc_serializable() is True
+    assert str(message) == f"{creator}: {content}"
+
+
+def test_ec_base_chat_message():
+    human_message_content = "Hello, how are you?"
+    ai_message_content = "I'm fine, thank you!"
+    human_metadata = {"user": "John"}
+    ai_metadata = {"response_time": 0.5}
+
+    chat_message = ChatMessage()
+    chat_message.add_user_message(human_message_content, metadata=human_metadata)
+    chat_message.add_ai_message(ai_message_content, metadata=ai_metadata)
+
+    assert chat_message.human_message.content == human_message_content
+    assert chat_message.human_message.creator == "human"
+    assert chat_message.human_message.metadata == human_metadata
+
+    assert chat_message.ai_message.content == ai_message_content
+    assert chat_message.ai_message.creator == "ai"
+    assert chat_message.ai_message.metadata == ai_metadata
+
+    assert str(chat_message) == f"human: {human_message_content} | ai: {ai_message_content}"

+ 1 - 0
tests/test_utils.py

@@ -1,4 +1,5 @@
 import yaml
 import yaml
+
 from embedchain.utils import validate_yaml_config
 from embedchain.utils import validate_yaml_config
 
 
 CONFIG_YAMLS = [
 CONFIG_YAMLS = [