123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172 |
- import logging
- from typing import Any, Dict, Optional
- from embedchain.helpers.json_serializable import JSONSerializable
- class BaseMessage(JSONSerializable):
- """
- The base abstract message class.
- Messages are the inputs and outputs of Models.
- """
- # The string content of the message.
- content: str
- # The created_by of the message. AI, Human, Bot etc.
- created_by: str
- # Any additional info.
- metadata: Dict[str, Any]
- def __init__(self, content: str, created_by: str, metadata: Optional[Dict[str, Any]] = None) -> None:
- super().__init__()
- self.content = content
- self.created_by = created_by
- self.metadata = metadata
- @property
- def type(self) -> str:
- """Type of the Message, used for serialization."""
- @classmethod
- def is_lc_serializable(cls) -> bool:
- """Return whether this class is serializable."""
- return True
- def __str__(self) -> str:
- return f"{self.created_by}: {self.content}"
- class ChatMessage(JSONSerializable):
- """
- The base abstract chat message class.
- Chat messages are the pair of (question, answer) conversation
- between human and model.
- """
- human_message: Optional[BaseMessage] = None
- ai_message: Optional[BaseMessage] = None
- def add_user_message(self, message: str, metadata: Optional[dict] = None):
- if self.human_message:
- logging.info(
- "Human message already exists in the chat message,\
- overwritting it with new message."
- )
- self.human_message = BaseMessage(content=message, created_by="human", metadata=metadata)
- def add_ai_message(self, message: str, metadata: Optional[dict] = None):
- if self.ai_message:
- logging.info(
- "AI message already exists in the chat message,\
- overwritting it with new message."
- )
- self.ai_message = BaseMessage(content=message, created_by="ai", metadata=metadata)
- def __str__(self) -> str:
- return f"{self.human_message}\n{self.ai_message}"
|