message.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. import logging
  2. from typing import Any, Dict, Optional
  3. from embedchain.helpers.json_serializable import JSONSerializable
  class BaseMessage(JSONSerializable):
      """
      Abstract base for a single message.

      A message is one unit of input to, or output from, a model: a piece of
      text, the identity of whoever produced it, and optional extra metadata.
      """

      # Text carried by the message.
      content: str
      # Producer of the message, e.g. AI, Human, Bot.
      created_by: str
      # Extra information attached to the message.
      metadata: Dict[str, Any]

      def __init__(
          self,
          content: str,
          created_by: str,
          metadata: Optional[Dict[str, Any]] = None,
      ) -> None:
          """Create a message.

          Args:
              content: The message text.
              created_by: Who produced the message.
              metadata: Optional extra information. Stored exactly as given,
                  so it is ``None`` when omitted.
          """
          super().__init__()
          self.content, self.created_by, self.metadata = content, created_by, metadata

      @property
      def type(self) -> str:
          """Message type name used for serialization.

          NOTE(review): this base implementation has no return statement, so it
          evaluates to None despite the ``str`` annotation; presumably subclasses
          are expected to override it — confirm.
          """

      @classmethod
      def is_lc_serializable(cls) -> bool:
          """Report that instances of this class can be serialized (always True)."""
          return True

      def __str__(self) -> str:
          """Render as ``"<created_by>: <content>"``."""
          return "{}: {}".format(self.created_by, self.content)
  class ChatMessage(JSONSerializable):
      """
      The base abstract chat message class.

      A chat message is one (question, answer) exchange in a conversation
      between a human and the model. Each side is held as a ``BaseMessage``
      and is ``None`` until it has been set.
      """

      # Human side of the exchange; defaults to None until add_user_message runs.
      human_message: Optional[BaseMessage] = None
      # Model side of the exchange; defaults to None until add_ai_message runs.
      ai_message: Optional[BaseMessage] = None

      def add_user_message(self, message: str, metadata: Optional[dict] = None):
          """Set the human side of this exchange.

          Any existing human message is replaced; an info-level log line records
          the overwrite.

          Args:
              message: Text of the human message.
              metadata: Optional extra information stored on the message.
          """
          if self.human_message is not None:
              # Adjacent string literals are concatenated at compile time; the
              # previous backslash continuation inside the literal leaked the
              # next line's indentation into the logged text.
              logging.info(
                  "Human message already exists in the chat message, "
                  "overwriting it with new message."
              )
          self.human_message = BaseMessage(content=message, created_by="human", metadata=metadata)

      def add_ai_message(self, message: str, metadata: Optional[dict] = None):
          """Set the model side of this exchange.

          Any existing AI message is replaced; an info-level log line records
          the overwrite.

          Args:
              message: Text of the AI message.
              metadata: Optional extra information stored on the message.
          """
          if self.ai_message is not None:
              logging.info(
                  "AI message already exists in the chat message, "
                  "overwriting it with new message."
              )
          self.ai_message = BaseMessage(content=message, created_by="ai", metadata=metadata)

      def __str__(self) -> str:
          """Render both sides on separate lines (``None`` for an unset side)."""
          return f"{self.human_message}\n{self.ai_message}"