# base.py
import logging
from collections.abc import Generator
from typing import Any, Optional

from langchain.schema import BaseMessage as LCBaseMessage

from embedchain.config import BaseLlmConfig
from embedchain.config.llm.base import (
    DEFAULT_PROMPT,
    DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
    DOCS_SITE_PROMPT_TEMPLATE,
)
from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.memory.base import ChatHistory
from embedchain.memory.message import ChatMessage

logger = logging.getLogger(__name__)

class BaseLlm(JSONSerializable):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        """Initialize a base LLM class

        :param config: LLM configuration option class, defaults to None
        :type config: Optional[BaseLlmConfig], optional
        """
        if config is None:
            self.config = BaseLlmConfig()
        else:
            self.config = config

        self.memory = ChatHistory()
        self.is_docs_site_instance = False
        self.online = False
        self.history: Any = None

    def get_llm_model_answer(self, prompt: str):
        """
        Usually implemented by child class
        """
        raise NotImplementedError
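
    # A minimal sketch of a concrete subclass; `EchoLlm` and its echo
    # behaviour are invented for illustration, not part of embedchain.
    # Real subclasses override this hook to call their provider and return
    # either a full string or an iterable of chunks for streaming:
    #
    #     class EchoLlm(BaseLlm):
    #         def get_llm_model_answer(self, prompt: str) -> str:
    #             return f"echo: {prompt}"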

    def set_history(self, history: Any):
        """
        Provide your own history.
        Especially useful for the query method, which does not internally manage conversation history.

        :param history: History to set
        :type history: Any
        """
        self.history = history
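
    # Usage sketch: since `query()` does not manage history on its own, a
    # caller can inject prior turns before querying (strings are made up):
    #
    #     llm.set_history(["human: What is embedchain?", "ai: A RAG framework."])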

    def update_history(self, app_id: str, session_id: str = "default"):
        """Update class history attribute with history in memory (for chat method)"""
        chat_history = self.memory.get(app_id=app_id, session_id=session_id, num_rounds=10)
        self.set_history([str(history) for history in chat_history])

    def add_history(
        self,
        app_id: str,
        question: str,
        answer: str,
        metadata: Optional[dict[str, Any]] = None,
        session_id: str = "default",
    ):
        """Store a question/answer round in memory and refresh the in-memory history."""
        chat_message = ChatMessage()
        chat_message.add_user_message(question, metadata=metadata)
        chat_message.add_ai_message(answer, metadata=metadata)
        self.memory.add(app_id=app_id, chat_message=chat_message, session_id=session_id)
        self.update_history(app_id=app_id, session_id=session_id)
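
    # Usage sketch with made-up identifiers: one Q/A round is stored as a
    # single ChatMessage, after which `self.history` holds the most recent
    # rounds as strings:
    #
    #     llm.add_history(app_id="demo-app", question="Hi", answer="Hello!")
    #     llm._format_history()  # newline-joined string of stored turns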

    def _format_history(self) -> str:
        """Format history to be used in prompt

        :return: Formatted history
        :rtype: str
        """
        return "\n".join(self.history)

    def generate_prompt(self, input_query: str, contexts: list[str], **kwargs: Any) -> str:
        """
        Generates a prompt based on the given query and context, ready to be
        passed to an LLM

        :param input_query: The query to use.
        :type input_query: str
        :param contexts: List of similar documents to the query used as context.
        :type contexts: list[str]
        :param kwargs: Optional extra arguments, e.g. `web_search_result`.
        :return: The prompt
        :rtype: str
        """
        context_string = " | ".join(contexts)
        web_search_result = kwargs.get("web_search_result", "")
        if web_search_result:
            context_string = self._append_search_and_context(context_string, web_search_result)

        prompt_contains_history = self.config._validate_prompt_history(self.config.prompt)
        if prompt_contains_history:
            # The prompt has a `$history` key, so substitute history directly.
            prompt = self.config.prompt.substitute(
                context=context_string, query=input_query, history=self._format_history() or "No history"
            )
        elif self.history:
            # History is present, but the prompt does not include a `$history` key.
            if self.config.prompt.template == DEFAULT_PROMPT:
                # Swap in the default template that includes history.
                prompt = DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE.substitute(
                    context=context_string, query=input_query, history=self._format_history()
                )
            else:
                # If we can't swap in the default, we still proceed but tell users that the history is ignored.
                logger.warning(
                    "Your bot contains a history, but prompt does not include `$history` key. History is ignored."
                )
                prompt = self.config.prompt.substitute(context=context_string, query=input_query)
        else:
            # Basic use case: no history.
            prompt = self.config.prompt.substitute(context=context_string, query=input_query)
        return prompt
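
    # Prompt objects here are `string.Template` instances, so a custom
    # prompt that includes `$history` is used as-is. A sketch with made-up
    # values, assuming `BaseLlmConfig` accepts a Template via `prompt=`:
    #
    #     from string import Template
    #     llm.config = BaseLlmConfig(prompt=Template("H: $history\nC: $context\nQ: $query"))
    #     llm.set_history(["human: hi", "ai: hello"])
    #     llm.generate_prompt("What now?", ["doc one", "doc two"])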

    @staticmethod
    def _append_search_and_context(context: str, web_search_result: str) -> str:
        """Append web search result to existing context

        :param context: Existing context
        :type context: str
        :param web_search_result: Web search result
        :type web_search_result: str
        :return: Concatenated context and web search result
        :rtype: str
        """
        return f"{context}\nWeb Search Result: {web_search_result}"

    def get_answer_from_llm(self, prompt: str):
        """
        Gets an answer based on the given query and context by passing it
        to an LLM.

        :param prompt: The prompt to send to the LLM.
        :type prompt: str
        :return: The answer, either as a string or as a stream of chunks.
        :rtype: Any
        """
        return self.get_llm_model_answer(prompt)

    @staticmethod
    def access_search_and_get_results(input_query: str):
        """
        Search the internet for additional context

        :param input_query: search query
        :type input_query: str
        :return: Search results
        :rtype: str
        """
        try:
            from langchain.tools import DuckDuckGoSearchRun
        except ImportError:
            raise ImportError(
                'Searching requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`'
            ) from None
        search = DuckDuckGoSearchRun()
        logger.info(f"Access search to get answers for {input_query}")
        return search.run(input_query)
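
    # Usage sketch: this needs the optional duckduckgo-search dependency
    # installed and returns a plain-text blob of results:
    #
    #     results = BaseLlm.access_search_and_get_results("embedchain docs")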

    @staticmethod
    def _stream_response(answer: Any) -> Generator[Any, Any, None]:
        """Generator to be used as streaming response

        :param answer: Answer chunk from llm
        :type answer: Any
        :yield: Answer chunk from llm
        :rtype: Generator[Any, Any, None]
        """
        streamed_answer = ""
        for chunk in answer:
            streamed_answer += chunk
            yield chunk
        logger.info(f"Answer: {streamed_answer}")
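
    # Usage sketch: chunks are re-yielded unchanged while being accumulated,
    # so the full answer can be logged once the stream is exhausted:
    #
    #     for token in BaseLlm._stream_response(iter(["Hel", "lo"])):
    #         print(token, end="")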

    def query(self, input_query: str, contexts: list[str], config: Optional[BaseLlmConfig] = None, dry_run=False):
        """
        Answers the given query by passing the contexts retrieved from the
        vector database to an LLM as part of the prompt.

        :param input_query: The query to use.
        :type input_query: str
        :param contexts: Embeddings retrieved from the database to be used as context.
        :type contexts: list[str]
        :param config: The `BaseLlmConfig` instance to use as configuration options. This is used for one method call.
            To persistently use a config, declare it during app init., defaults to None
        :type config: Optional[BaseLlmConfig], optional
        :param dry_run: A dry run does everything except send the resulting prompt to
            the LLM. The purpose is to test the prompt, not the response., defaults to False
        :type dry_run: bool, optional
        :return: The answer to the query or the dry run result
        :rtype: str
        """
        try:
            if config:
                # A config instance passed to this method is applied temporarily, for one call only,
                # so save the previous config (via the serializer) and restore it at the end of the execution.
                prev_config = self.config.serialize()
                self.config = config

            if config is not None and config.query_type == "Images":
                return contexts

            if self.is_docs_site_instance:
                self.config.prompt = DOCS_SITE_PROMPT_TEMPLATE
                self.config.number_documents = 5

            k = {}
            if self.online:
                k["web_search_result"] = self.access_search_and_get_results(input_query)

            prompt = self.generate_prompt(input_query, contexts, **k)
            logger.info(f"Prompt: {prompt}")
            if dry_run:
                return prompt

            answer = self.get_answer_from_llm(prompt)
            if isinstance(answer, str):
                logger.info(f"Answer: {answer}")
                return answer
            else:
                # This is a streamed response and needs to be handled differently.
                return self._stream_response(answer)
        finally:
            if config:
                # Restore previous config
                self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config)
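
    # Usage sketch with invented contexts: `dry_run=True` returns the fully
    # rendered prompt without calling the model, which is useful when
    # testing prompt construction in isolation:
    #
    #     prompt = llm.query("What is X?", ["context A", "context B"], dry_run=True)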

    def chat(
        self,
        input_query: str,
        contexts: list[str],
        config: Optional[BaseLlmConfig] = None,
        dry_run=False,
        session_id: Optional[str] = None,
    ):
        """
        Answers the given query by passing the contexts retrieved from the
        vector database to an LLM as part of the prompt.
        Maintains the whole conversation in memory.

        :param input_query: The query to use.
        :type input_query: str
        :param contexts: Embeddings retrieved from the database to be used as context.
        :type contexts: list[str]
        :param config: The `BaseLlmConfig` instance to use as configuration options. This is used for one method call.
            To persistently use a config, declare it during app init., defaults to None
        :type config: Optional[BaseLlmConfig], optional
        :param dry_run: A dry run does everything except send the resulting prompt to
            the LLM. The purpose is to test the prompt, not the response., defaults to False
        :type dry_run: bool, optional
        :param session_id: Session ID to use for the conversation, defaults to None
        :type session_id: Optional[str], optional
        :return: The answer to the query or the dry run result
        :rtype: str
        """
        try:
            if config:
                # A config instance passed to this method is applied temporarily, for one call only,
                # so save the previous config (via the serializer) and restore it at the end of the execution.
                prev_config = self.config.serialize()
                self.config = config

            if self.is_docs_site_instance:
                self.config.prompt = DOCS_SITE_PROMPT_TEMPLATE
                self.config.number_documents = 5

            k = {}
            if self.online:
                k["web_search_result"] = self.access_search_and_get_results(input_query)

            prompt = self.generate_prompt(input_query, contexts, **k)
            logger.info(f"Prompt: {prompt}")
            if dry_run:
                return prompt

            answer = self.get_answer_from_llm(prompt)
            if isinstance(answer, str):
                logger.info(f"Answer: {answer}")
                return answer
            else:
                # This is a streamed response and needs to be handled differently.
                return self._stream_response(answer)
        finally:
            if config:
                # Restore previous config
                self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config)
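
    # Usage sketch: `chat()` renders the same prompt as `query()`; the
    # surrounding app is expected to record turns (e.g. via `add_history`)
    # so that `$history` is populated on subsequent calls:
    #
    #     answer = llm.chat("And what about Y?", ["context A"], session_id="default")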

    @staticmethod
    def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> list[LCBaseMessage]:
        """
        Construct a list of langchain messages

        :param prompt: User prompt
        :type prompt: str
        :param system_prompt: System prompt, defaults to None
        :type system_prompt: Optional[str], optional
        :return: List of messages
        :rtype: list[LCBaseMessage]
        """
        from langchain.schema import HumanMessage, SystemMessage

        messages = []
        if system_prompt:
            messages.append(SystemMessage(content=system_prompt))
        messages.append(HumanMessage(content=prompt))
        return messages
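

# A runnable end-to-end sketch, assuming embedchain is installed; `EchoLlm`
# is a stand-in subclass invented here for illustration only.
if __name__ == "__main__":

    class EchoLlm(BaseLlm):
        def get_llm_model_answer(self, prompt: str) -> str:
            # A real subclass would call its model provider here.
            return f"echo: {prompt[:60]}..."

    llm = EchoLlm()
    # Dry run: inspect the rendered prompt without invoking the model.
    print(llm.query("What is embedchain?", ["embedchain is a RAG framework"], dry_run=True))
    # Full run: goes through get_answer_from_llm to EchoLlm's override.
    print(llm.query("What is embedchain?", ["embedchain is a RAG framework"]))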