base.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. import json
  2. import logging
  3. import sqlite3
  4. import uuid
  5. from typing import Any, Dict, List, Optional
  6. from embedchain.constants import SQLITE_PATH
  7. from embedchain.memory.message import ChatMessage
  8. from embedchain.memory.utils import merge_metadata_dict
# DDL for the chat history store: one row per human/AI round.
# Rows are keyed by (id, app_id); `metadata` holds a JSON-encoded dict as TEXT and
# is nullable; `created_at` defaults to SQLite's CURRENT_TIMESTAMP (UTC text with
# one-second resolution). IF NOT EXISTS makes it safe to execute on every startup.
CHAT_MESSAGE_CREATE_TABLE_QUERY = """
CREATE TABLE IF NOT EXISTS chat_history (
app_id TEXT,
id TEXT,
question TEXT,
answer TEXT,
metadata TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (id, app_id)
)
"""
  20. class ECChatMemory:
  21. def __init__(self) -> None:
  22. with sqlite3.connect(SQLITE_PATH, check_same_thread=False) as self.connection:
  23. self.cursor = self.connection.cursor()
  24. self.cursor.execute(CHAT_MESSAGE_CREATE_TABLE_QUERY)
  25. self.connection.commit()
  26. def add(self, app_id, chat_message: ChatMessage) -> Optional[str]:
  27. memory_id = str(uuid.uuid4())
  28. metadata_dict = merge_metadata_dict(chat_message.human_message.metadata, chat_message.ai_message.metadata)
  29. if metadata_dict:
  30. metadata = self._serialize_json(metadata_dict)
  31. ADD_CHAT_MESSAGE_QUERY = """
  32. INSERT INTO chat_history (app_id, id, question, answer, metadata)
  33. VALUES (?, ?, ?, ?, ?)
  34. """
  35. self.cursor.execute(
  36. ADD_CHAT_MESSAGE_QUERY,
  37. (
  38. app_id,
  39. memory_id,
  40. chat_message.human_message.content,
  41. chat_message.ai_message.content,
  42. metadata if metadata_dict else "{}",
  43. ),
  44. )
  45. self.connection.commit()
  46. logging.info(f"Added chat memory to db with id: {memory_id}")
  47. return memory_id
  48. def delete_chat_history(self, app_id: str):
  49. DELETE_CHAT_HISTORY_QUERY = """
  50. DELETE FROM chat_history WHERE app_id=?
  51. """
  52. self.cursor.execute(
  53. DELETE_CHAT_HISTORY_QUERY,
  54. (app_id,),
  55. )
  56. self.connection.commit()
  57. def get_recent_memories(self, app_id, num_rounds=10, display_format=False) -> List[ChatMessage]:
  58. """
  59. Get the most recent num_rounds rounds of conversations
  60. between human and AI, for a given app_id.
  61. """
  62. QUERY = """
  63. SELECT * FROM chat_history
  64. WHERE app_id=?
  65. ORDER BY created_at DESC
  66. LIMIT ?
  67. """
  68. self.cursor.execute(
  69. QUERY,
  70. (app_id, num_rounds),
  71. )
  72. results = self.cursor.fetchall()
  73. history = []
  74. for result in results:
  75. app_id, _, question, answer, metadata, timestamp = result
  76. metadata = self._deserialize_json(metadata=metadata)
  77. # Return list of dict if display_format is True
  78. if display_format:
  79. history.append({"human": question, "ai": answer, "metadata": metadata, "timestamp": timestamp})
  80. else:
  81. memory = ChatMessage()
  82. memory.add_user_message(question, metadata=metadata)
  83. memory.add_ai_message(answer, metadata=metadata)
  84. history.append(memory)
  85. return history
  86. def _serialize_json(self, metadata: Dict[str, Any]):
  87. return json.dumps(metadata)
  88. def _deserialize_json(self, metadata: str):
  89. return json.loads(metadata)
  90. def close_connection(self):
  91. self.connection.close()
  92. def count_history_messages(self, app_id: str):
  93. QUERY = """
  94. SELECT COUNT(*) FROM chat_history
  95. WHERE app_id=?
  96. """
  97. self.cursor.execute(
  98. QUERY,
  99. (app_id,),
  100. )
  101. count = self.cursor.fetchone()[0]
  102. return count