message.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. import logging
  2. from typing import Any, Optional
  3. from embedchain.helpers.json_serializable import JSONSerializable
  4. logger = logging.getLogger(__name__)
  5. class BaseMessage(JSONSerializable):
  6. """
  7. The base abstract message class.
  8. Messages are the inputs and outputs of Models.
  9. """
  10. # The string content of the message.
  11. content: str
  12. # The created_by of the message. AI, Human, Bot etc.
  13. created_by: str
  14. # Any additional info.
  15. metadata: dict[str, Any]
  16. def __init__(self, content: str, created_by: str, metadata: Optional[dict[str, Any]] = None) -> None:
  17. super().__init__()
  18. self.content = content
  19. self.created_by = created_by
  20. self.metadata = metadata
  21. @property
  22. def type(self) -> str:
  23. """Type of the Message, used for serialization."""
  24. @classmethod
  25. def is_lc_serializable(cls) -> bool:
  26. """Return whether this class is serializable."""
  27. return True
  28. def __str__(self) -> str:
  29. return f"{self.created_by}: {self.content}"
  30. class ChatMessage(JSONSerializable):
  31. """
  32. The base abstract chat message class.
  33. Chat messages are the pair of (question, answer) conversation
  34. between human and model.
  35. """
  36. human_message: Optional[BaseMessage] = None
  37. ai_message: Optional[BaseMessage] = None
  38. def add_user_message(self, message: str, metadata: Optional[dict] = None):
  39. if self.human_message:
  40. logger.info(
  41. "Human message already exists in the chat message,\
  42. overwriting it with new message."
  43. )
  44. self.human_message = BaseMessage(content=message, created_by="human", metadata=metadata)
  45. def add_ai_message(self, message: str, metadata: Optional[dict] = None):
  46. if self.ai_message:
  47. logger.info(
  48. "AI message already exists in the chat message,\
  49. overwriting it with new message."
  50. )
  51. self.ai_message = BaseMessage(content=message, created_by="ai", metadata=metadata)
  52. def __str__(self) -> str:
  53. return f"{self.human_message}\n{self.ai_message}"