123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164 |
- import json
- import logging
- import sqlite3
- import uuid
- from typing import Any, Optional
- from embedchain.constants import SQLITE_PATH
- from embedchain.memory.message import ChatMessage
- from embedchain.memory.utils import merge_metadata_dict
- CHAT_MESSAGE_CREATE_TABLE_QUERY = """
- CREATE TABLE IF NOT EXISTS ec_chat_history (
- app_id TEXT,
- id TEXT,
- session_id TEXT,
- question TEXT,
- answer TEXT,
- metadata TEXT,
- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
- PRIMARY KEY (id, app_id, session_id)
- )
- """
- class ChatHistory:
- def __init__(self) -> None:
- with sqlite3.connect(SQLITE_PATH, check_same_thread=False) as self.connection:
- self.cursor = self.connection.cursor()
- self.cursor.execute(CHAT_MESSAGE_CREATE_TABLE_QUERY)
- self.connection.commit()
- def add(self, app_id, session_id, chat_message: ChatMessage) -> Optional[str]:
- memory_id = str(uuid.uuid4())
- metadata_dict = merge_metadata_dict(chat_message.human_message.metadata, chat_message.ai_message.metadata)
- if metadata_dict:
- metadata = self._serialize_json(metadata_dict)
- ADD_CHAT_MESSAGE_QUERY = """
- INSERT INTO ec_chat_history (app_id, id, session_id, question, answer, metadata)
- VALUES (?, ?, ?, ?, ?, ?)
- """
- self.cursor.execute(
- ADD_CHAT_MESSAGE_QUERY,
- (
- app_id,
- memory_id,
- session_id,
- chat_message.human_message.content,
- chat_message.ai_message.content,
- metadata if metadata_dict else "{}",
- ),
- )
- self.connection.commit()
- logging.info(f"Added chat memory to db with id: {memory_id}")
- return memory_id
- def delete(self, app_id: str, session_id: Optional[str] = None):
- """
- Delete all chat history for a given app_id and session_id.
- This is useful for deleting chat history for a given user.
- :param app_id: The app_id to delete chat history for
- :param session_id: The session_id to delete chat history for
- :return: None
- """
- if session_id:
- DELETE_CHAT_HISTORY_QUERY = "DELETE FROM ec_chat_history WHERE app_id=? AND session_id=?"
- params = (app_id, session_id)
- else:
- DELETE_CHAT_HISTORY_QUERY = "DELETE FROM ec_chat_history WHERE app_id=?"
- params = (app_id,)
- self.cursor.execute(DELETE_CHAT_HISTORY_QUERY, params)
- self.connection.commit()
- def get(
- self, app_id, session_id: str = "default", num_rounds=10, fetch_all: bool = False, display_format=False
- ) -> list[ChatMessage]:
- """
- Get the chat history for a given app_id.
- param: app_id - The app_id to get chat history
- param: session_id (optional) - The session_id to get chat history. Defaults to "default"
- param: num_rounds (optional) - The number of rounds to get chat history. Defaults to 10
- param: fetch_all (optional) - Whether to fetch all chat history or not. Defaults to False
- param: display_format (optional) - Whether to return the chat history in display format. Defaults to False
- """
- base_query = """
- SELECT * FROM ec_chat_history
- WHERE app_id=?
- """
- if fetch_all:
- additional_query = "ORDER BY created_at ASC"
- params = (app_id,)
- else:
- additional_query = """
- AND session_id=?
- ORDER BY created_at ASC
- LIMIT ?
- """
- params = (app_id, session_id, num_rounds)
- QUERY = base_query + additional_query
- self.cursor.execute(
- QUERY,
- params,
- )
- results = self.cursor.fetchall()
- history = []
- for result in results:
- app_id, _, session_id, question, answer, metadata, timestamp = result
- metadata = self._deserialize_json(metadata=metadata)
- # Return list of dict if display_format is True
- if display_format:
- history.append(
- {
- "session_id": session_id,
- "human": question,
- "ai": answer,
- "metadata": metadata,
- "timestamp": timestamp,
- }
- )
- else:
- memory = ChatMessage()
- memory.add_user_message(question, metadata=metadata)
- memory.add_ai_message(answer, metadata=metadata)
- history.append(memory)
- return history
- def count(self, app_id: str, session_id: Optional[str] = None):
- """
- Count the number of chat messages for a given app_id and session_id.
- :param app_id: The app_id to count chat history for
- :param session_id: The session_id to count chat history for
- :return: The number of chat messages for a given app_id and session_id
- """
- if session_id:
- QUERY = "SELECT COUNT(*) FROM ec_chat_history WHERE app_id=? AND session_id=?"
- params = (app_id, session_id)
- else:
- QUERY = "SELECT COUNT(*) FROM ec_chat_history WHERE app_id=?"
- params = (app_id,)
- self.cursor.execute(QUERY, params)
- count = self.cursor.fetchone()[0]
- return count
- @staticmethod
- def _serialize_json(metadata: dict[str, Any]):
- return json.dumps(metadata)
- @staticmethod
- def _deserialize_json(metadata: str):
- return json.loads(metadata)
- def close_connection(self):
- self.connection.close()
|