base.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. import json
  2. import logging
  3. import sqlite3
  4. import uuid
  5. from typing import Any, Dict, List, Optional
  6. from embedchain.constants import SQLITE_PATH
  7. from embedchain.memory.message import ChatMessage
  8. from embedchain.memory.utils import merge_metadata_dict
  # DDL for the chat-history table. Each row is one question/answer round scoped by
  # app_id and session_id; `metadata` holds a JSON string. IF NOT EXISTS makes it safe
  # to execute on every ChatHistory construction.
  CHAT_MESSAGE_CREATE_TABLE_QUERY = """
      CREATE TABLE IF NOT EXISTS ec_chat_history (
          app_id TEXT,
          id TEXT,
          session_id TEXT,
          question TEXT,
          answer TEXT,
          metadata TEXT,
          created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
          PRIMARY KEY (id, app_id, session_id)
      )
  """
  class ChatHistory:
      """Store and query human/AI chat rounds in the ``ec_chat_history`` SQLite table.

      Each row holds one round (the human question and the AI answer) plus its
      metadata serialized as a JSON string, scoped by ``app_id`` and ``session_id``.
      """

      def __init__(self) -> None:
          # Using the connection as a context manager only wraps the table creation in a
          # transaction (commit on success, rollback on error); sqlite3 does not close the
          # connection on exit, so it stays open until close_connection() is called.
          # NOTE(review): check_same_thread=False allows use from other threads, but nothing
          # here serializes access to the single shared cursor -- callers must synchronize.
          with sqlite3.connect(SQLITE_PATH, check_same_thread=False) as self.connection:
              self.cursor = self.connection.cursor()
              self.cursor.execute(CHAT_MESSAGE_CREATE_TABLE_QUERY)
              self.connection.commit()

      def add(self, app_id, session_id, chat_message: "ChatMessage") -> Optional[str]:
          """
          Persist one human/AI round and return its generated id.

          :param app_id: The app the round belongs to
          :param session_id: The session the round belongs to
          :param chat_message: The round to store; the human and AI message metadata are
              combined with merge_metadata_dict and stored as JSON ("{}" when empty)
          :return: The new row id, a random UUID4 string
          """
          memory_id = str(uuid.uuid4())
          metadata_dict = merge_metadata_dict(chat_message.human_message.metadata, chat_message.ai_message.metadata)
          # Empty/None metadata is stored as an empty JSON object so rows always hold valid JSON.
          metadata = self._serialize_json(metadata_dict) if metadata_dict else "{}"
          ADD_CHAT_MESSAGE_QUERY = """
              INSERT INTO ec_chat_history (app_id, id, session_id, question, answer, metadata)
              VALUES (?, ?, ?, ?, ?, ?)
          """
          self.cursor.execute(
              ADD_CHAT_MESSAGE_QUERY,
              (
                  app_id,
                  memory_id,
                  session_id,
                  chat_message.human_message.content,
                  chat_message.ai_message.content,
                  metadata,
              ),
          )
          self.connection.commit()
          logging.info("Added chat memory to db with id: %s", memory_id)
          return memory_id

      def delete(self, app_id: str, session_id: str):
          """
          Delete all chat history for a given app_id and session_id.
          This is useful for deleting chat history for a given user.

          :param app_id: The app_id to delete chat history for
          :param session_id: The session_id to delete chat history for
          :return: None
          """
          DELETE_CHAT_HISTORY_QUERY = "DELETE FROM ec_chat_history WHERE app_id=? AND session_id=?"
          self.cursor.execute(DELETE_CHAT_HISTORY_QUERY, (app_id, session_id))
          self.connection.commit()

      def get(self, app_id, session_id, num_rounds=10, display_format=False) -> List["ChatMessage"]:
          """
          Get the most recent ``num_rounds`` rounds between human and AI, newest first.

          CURRENT_TIMESTAMP only has one-second resolution, so rounds stored within the
          same second are ordered by insertion (rowid), newest first, instead of arbitrarily.

          :param app_id: The app_id to read chat history for
          :param session_id: The session_id to read chat history for
          :param num_rounds: Maximum number of rounds to return (SQL LIMIT)
          :param display_format: When True, return plain dicts with keys
              "human", "ai", "metadata" and "timestamp" instead of ChatMessage objects
          :return: List of ChatMessage objects, or of dicts when display_format is True
          """
          QUERY = """
              SELECT question, answer, metadata, created_at FROM ec_chat_history
              WHERE app_id=? AND session_id=?
              ORDER BY created_at DESC, rowid DESC
              LIMIT ?
          """
          self.cursor.execute(
              QUERY,
              (app_id, session_id, num_rounds),
          )
          history = []
          for question, answer, raw_metadata, timestamp in self.cursor.fetchall():
              metadata = self._deserialize_json(metadata=raw_metadata)
              if display_format:
                  history.append({"human": question, "ai": answer, "metadata": metadata, "timestamp": timestamp})
              else:
                  memory = ChatMessage()
                  memory.add_user_message(question, metadata=metadata)
                  memory.add_ai_message(answer, metadata=metadata)
                  history.append(memory)
          return history

      def count(self, app_id: str, session_id: str):
          """
          Count the number of chat messages for a given app_id and session_id.

          :param app_id: The app_id to count chat history for
          :param session_id: The session_id to count chat history for
          :return: The number of stored rounds for the given app_id and session_id
          """
          QUERY = "SELECT COUNT(*) FROM ec_chat_history WHERE app_id=? AND session_id=?"
          self.cursor.execute(QUERY, (app_id, session_id))
          (count,) = self.cursor.fetchone()
          return count

      def _serialize_json(self, metadata: Dict[str, Any]) -> str:
          """Encode a metadata dict as a JSON string for the ``metadata`` column."""
          return json.dumps(metadata)

      def _deserialize_json(self, metadata: Optional[str]) -> Any:
          """Decode the ``metadata`` column; NULL or empty values decode to an empty dict."""
          # add() always writes a JSON string, but rows written by other means may hold
          # NULL; json.loads(None) would raise and abort the whole get() call.
          if not metadata:
              return {}
          return json.loads(metadata)

      def close_connection(self):
          """Close the underlying SQLite connection; later queries on this instance will fail."""
          self.connection.close()