test_chat_memory.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. import pytest
  2. from embedchain.memory.base import ChatHistory
  3. from embedchain.memory.message import ChatMessage
  4. # Fixture for creating an instance of ChatHistory
  @pytest.fixture
  def chat_memory_instance():
      """Provide a fresh ``ChatHistory`` and close it when the test finishes.

      Written as a yield fixture so pytest runs the teardown after every test
      that requests it, including tests whose assertions fail. Previously the
      instance was returned without ever calling ``close_connection()``.
      """
      history = ChatHistory()
      yield history
      # Teardown: runs after the requesting test, pass or fail.
      history.close_connection()
  def test_add_chat_memory(chat_memory_instance):
      """Adding one user/AI exchange stores exactly one round for the session."""
      app_id = "test_app"
      session_id = "test_session"
      human_message = "Hello, how are you?"
      ai_message = "I'm fine, thank you!"

      # Start from an empty session: the store may still hold rows from an
      # earlier (failed) run, which would break the exact count below.
      chat_memory_instance.delete(app_id, session_id)
      try:
          chat_message = ChatMessage()
          chat_message.add_user_message(human_message)
          chat_message.add_ai_message(ai_message)
          chat_memory_instance.add(app_id, session_id, chat_message)
          assert chat_memory_instance.count(app_id, session_id) == 1
      finally:
          # Clean up even when the assertion fails so other tests start clean.
          chat_memory_instance.delete(app_id, session_id)
  def test_get(chat_memory_instance):
      """``get`` limits results via ``num_rounds`` and ``fetch_all`` returns all six rounds."""
      app_id = "test_app"
      session_id = "test_session"

      # The fetch_all query below passes no session_id, so any leftover rows
      # under app_id could be counted — clear the whole app first.
      chat_memory_instance.delete(app_id)
      try:
          for i in range(1, 7):
              chat_message = ChatMessage()
              chat_message.add_user_message(f"Question {i}")
              chat_message.add_ai_message(f"Answer {i}")
              chat_memory_instance.add(app_id, session_id, chat_message)

          # Six rounds are stored; asking for five must return exactly five.
          recent_memories = chat_memory_instance.get(app_id, session_id, num_rounds=5)
          assert len(recent_memories) == 5

          all_memories = chat_memory_instance.get(app_id, fetch_all=True)
          assert len(all_memories) == 6
      finally:
          # Remove the rounds added here so they cannot leak into other tests
          # or later runs against the same store.
          chat_memory_instance.delete(app_id)
  def test_delete_chat_history(chat_memory_instance):
      """Deleting by session removes only that session; deleting by app removes everything."""
      app_id = "test_app"
      session_id = "test_session"
      session_id_2 = "test_session_2"

      # The app-wide counts below are exact, so start from an empty app.
      chat_memory_instance.delete(app_id)
      try:
          # Populate both sessions (first session_id, then session_id_2) with
          # the same five question/answer rounds.
          for sid in (session_id, session_id_2):
              for i in range(1, 6):
                  chat_message = ChatMessage()
                  chat_message.add_user_message(f"Question {i}")
                  chat_message.add_ai_message(f"Answer {i}")
                  chat_memory_instance.add(app_id, sid, chat_message)

          # Session-scoped delete: only session_id's rounds disappear.
          chat_memory_instance.delete(app_id, session_id)
          assert chat_memory_instance.count(app_id, session_id) == 0
          assert chat_memory_instance.count(app_id) == 5

          # App-scoped delete: nothing remains under app_id.
          chat_memory_instance.delete(app_id)
          assert chat_memory_instance.count(app_id) == 0
      finally:
          # Guarantee cleanup when an assertion above fails midway.
          chat_memory_instance.delete(app_id)
  @pytest.fixture
  def close_connection(chat_memory_instance):
      """Call ``close_connection()`` on the ``ChatHistory`` after the test runs.

      NOTE(review): none of the tests above list ``close_connection`` as a
      parameter and it is not ``autouse``, so as written this teardown never
      executes — confirm whether it is still needed.
      """
      history = chat_memory_instance
      yield  # the requesting test runs here
      history.close_connection()