# base.py

import logging
from collections.abc import Generator
from typing import Any, Optional

from langchain.schema import BaseMessage as LCBaseMessage

from embedchain.config import BaseLlmConfig
from embedchain.config.llm.base import (
    DEFAULT_PROMPT,
    DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
    DEFAULT_PROMPT_WITH_MEM0_MEMORY_TEMPLATE,
    DOCS_SITE_PROMPT_TEMPLATE,
)
from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.memory.base import ChatHistory
from embedchain.memory.message import ChatMessage

logger = logging.getLogger(__name__)

class BaseLlm(JSONSerializable):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        """Initialize a base LLM class

        :param config: LLM configuration option class, defaults to None
        :type config: Optional[BaseLlmConfig], optional
        """
        if config is None:
            self.config = BaseLlmConfig()
        else:
            self.config = config

        self.memory = ChatHistory()
        self.is_docs_site_instance = False
        self.history: Any = None

    def get_llm_model_answer(self):
        """
        Usually implemented by child class
        """
        raise NotImplementedError

    def set_history(self, history: Any):
        """
        Provide your own history.
        Especially interesting for the query method, which does not internally manage conversation history.

        :param history: History to set
        :type history: Any
        """
        self.history = history

    def update_history(self, app_id: str, session_id: str = "default"):
        """Update class history attribute with history in memory (for chat method)"""
        chat_history = self.memory.get(app_id=app_id, session_id=session_id, num_rounds=10)
        self.set_history([str(history) for history in chat_history])

    def add_history(
        self,
        app_id: str,
        question: str,
        answer: str,
        metadata: Optional[dict[str, Any]] = None,
        session_id: str = "default",
    ):
        """Store one question/answer round in memory and refresh the history attribute."""
        chat_message = ChatMessage()
        chat_message.add_user_message(question, metadata=metadata)
        chat_message.add_ai_message(answer, metadata=metadata)
        self.memory.add(app_id=app_id, chat_message=chat_message, session_id=session_id)
        self.update_history(app_id=app_id, session_id=session_id)
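
    # Usage sketch (not part of the original module; names and values are placeholders):
    # the history helpers are normally driven by the embedding app, but they can be
    # exercised directly on any concrete subclass instance `llm`.
    #
    #     llm.add_history(app_id="my-app", question="What is Embedchain?",
    #                     answer="A RAG framework.", session_id="default")
    #     llm.set_history(["user: hi", "assistant: hello"])  # or supply history manually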

    def _format_history(self) -> str:
        """Format history to be used in prompt

        :return: Formatted history
        :rtype: str
        """
        return "\n".join(self.history)

    def _format_memories(self, memories: list[dict]) -> str:
        """Format memories to be used in prompt

        :param memories: Memories to format
        :type memories: list[dict]
        :return: Formatted memories
        :rtype: str
        """
        return "\n".join([memory["text"] for memory in memories])

    def generate_prompt(self, input_query: str, contexts: list[str], **kwargs: dict[str, Any]) -> str:
        """
        Generates a prompt based on the given query and context, ready to be
        passed to an LLM

        :param input_query: The query to use.
        :type input_query: str
        :param contexts: List of similar documents to the query used as context.
        :type contexts: list[str]
        :return: The prompt
        :rtype: str
        """
        context_string = " | ".join(contexts)
        web_search_result = kwargs.get("web_search_result", "")
        memories = kwargs.get("memories", None)
        if web_search_result:
            context_string = self._append_search_and_context(context_string, web_search_result)

        prompt_contains_history = self.config._validate_prompt_history(self.config.prompt)
        if prompt_contains_history:
            prompt = self.config.prompt.substitute(
                context=context_string, query=input_query, history=self._format_history() or "No history"
            )
        elif self.history and not prompt_contains_history:
            # History is present, but not included in the prompt.
            # Check if it's the default prompt without history.
            if (
                not self.config._validate_prompt_history(self.config.prompt)
                and self.config.prompt.template == DEFAULT_PROMPT
            ):
                if memories:
                    # Swap in the template with the Mem0 memory template.
                    prompt = DEFAULT_PROMPT_WITH_MEM0_MEMORY_TEMPLATE.substitute(
                        context=context_string,
                        query=input_query,
                        history=self._format_history(),
                        memories=self._format_memories(memories),
                    )
                else:
                    # Swap in the template with history.
                    prompt = DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE.substitute(
                        context=context_string, query=input_query, history=self._format_history()
                    )
            else:
                # If we can't swap in the default, we still proceed but tell users that the history is ignored.
                logger.warning(
                    "Your bot contains a history, but prompt does not include `$history` key. History is ignored."
                )
                prompt = self.config.prompt.substitute(context=context_string, query=input_query)
        else:
            # Basic use case, no history.
            prompt = self.config.prompt.substitute(context=context_string, query=input_query)

        return prompt
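
    # Prompt-assembly sketch (illustrative only; the query, chunks, and kwargs below are
    # made-up values): contexts are joined with " | ", and an optional web search result
    # or Mem0-style memories are folded in via the matching default template.
    #
    #     prompt = llm.generate_prompt(
    #         "What is Embedchain?",
    #         ["doc chunk one", "doc chunk two"],
    #         web_search_result="",  # filled in when config.online is True
    #         memories=None,         # list[dict] with a "text" key when memories are used
    #     )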

    @staticmethod
    def _append_search_and_context(context: str, web_search_result: str) -> str:
        """Append web search context to existing context

        :param context: Existing context
        :type context: str
        :param web_search_result: Web search result
        :type web_search_result: str
        :return: Concatenated web search result
        :rtype: str
        """
        return f"{context}\nWeb Search Result: {web_search_result}"

    def get_answer_from_llm(self, prompt: str):
        """
        Gets an answer based on the given query and context by passing it
        to an LLM.

        :param prompt: The fully formatted prompt to send to the LLM.
        :type prompt: str
        :return: The answer, or a stream/token-usage pair depending on the subclass and config.
        :rtype: Any
        """
        return self.get_llm_model_answer(prompt)

    @staticmethod
    def access_search_and_get_results(input_query: str):
        """
        Search the internet for additional context

        :param input_query: search query
        :type input_query: str
        :return: Search results
        :rtype: str
        """
        try:
            from langchain.tools import DuckDuckGoSearchRun
        except ImportError:
            raise ImportError(
                "Searching requires extra dependencies. Install with `pip install duckduckgo-search==6.1.5`"
            ) from None
        search = DuckDuckGoSearchRun()
        logger.info(f"Access search to get answers for {input_query}")
        return search.run(input_query)
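
    # Example (sketch): the search helper is a static method, so it can be called without
    # instantiating an LLM, provided the duckduckgo-search extra is installed. The query
    # string is a placeholder.
    #
    #     results = BaseLlm.access_search_and_get_results("embedchain documentation")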

    @staticmethod
    def _stream_response(answer: Any, token_info: Optional[dict[str, Any]] = None) -> Generator[Any, Any, None]:
        """Generator to be used as streaming response

        :param answer: Answer chunk from llm
        :type answer: Any
        :param token_info: Token usage info, logged once streaming finishes, defaults to None
        :type token_info: Optional[dict[str, Any]], optional
        :yield: Answer chunk from llm
        :rtype: Generator[Any, Any, None]
        """
        streamed_answer = ""
        for chunk in answer:
            streamed_answer = streamed_answer + chunk
            yield chunk
        logger.info(f"Answer: {streamed_answer}")
        if token_info:
            logger.info(f"Token Info: {token_info}")
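
    # Streaming sketch (illustrative): `_stream_response` re-yields chunks while
    # accumulating them for a final log line, so callers iterate over it like any other
    # generator. The chunk list here stands in for a real streamed answer.
    #
    #     for chunk in BaseLlm._stream_response(["Hello, ", "world"]):
    #         print(chunk, end="")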

    def query(self, input_query: str, contexts: list[str], config: BaseLlmConfig = None, dry_run=False, memories=None):
        """
        Queries the vector database based on the given input query.
        Gets relevant doc based on the query and then passes it to an
        LLM as context to get the answer.

        :param input_query: The query to use.
        :type input_query: str
        :param contexts: Embeddings retrieved from the database to be used as context.
        :type contexts: list[str]
        :param config: The `BaseLlmConfig` instance to use as configuration options. This is used for one method call.
            To persistently use a config, declare it during app init., defaults to None
        :type config: Optional[BaseLlmConfig], optional
        :param dry_run: A dry run does everything except send the resulting prompt to
            the LLM. The purpose is to test the prompt, not the response., defaults to False
        :type dry_run: bool, optional
        :param memories: Memories (e.g. from Mem0) to include in the prompt, defaults to None
        :type memories: Optional[list[dict]], optional
        :return: The answer to the query or the dry run result
        :rtype: str
        """
        try:
            if config:
                # A config instance passed to this method will only be applied temporarily, for one call.
                # So we will save the previous config and restore it at the end of the execution.
                # For this we use the serializer.
                prev_config = self.config.serialize()
                self.config = config

            if config is not None and config.query_type == "Images":
                return contexts

            if self.is_docs_site_instance:
                self.config.prompt = DOCS_SITE_PROMPT_TEMPLATE
                self.config.number_documents = 5

            k = {}
            if self.config.online:
                k["web_search_result"] = self.access_search_and_get_results(input_query)
            k["memories"] = memories

            prompt = self.generate_prompt(input_query, contexts, **k)
            logger.info(f"Prompt: {prompt}")
            if dry_run:
                return prompt

            if self.config.token_usage:
                answer, token_info = self.get_answer_from_llm(prompt)
            else:
                answer = self.get_answer_from_llm(prompt)

            if isinstance(answer, str):
                logger.info(f"Answer: {answer}")
                if self.config.token_usage:
                    return answer, token_info
                return answer
            else:
                if self.config.token_usage:
                    return self._stream_response(answer, token_info)
                return self._stream_response(answer)
        finally:
            if config:
                # Restore previous config
                self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config)
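
    # Dry-run sketch (assumes default config values, e.g. `online` and `token_usage`
    # disabled): with `dry_run=True` the LLM is never called, so even the base class can
    # be used to inspect the prompt that would be sent. Query and contexts are placeholders.
    #
    #     llm = BaseLlm(BaseLlmConfig())
    #     prompt = llm.query("What is Embedchain?", ["retrieved chunk"], dry_run=True)
    #     print(prompt)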

    def chat(
        self, input_query: str, contexts: list[str], config: BaseLlmConfig = None, dry_run=False, session_id: str = None
    ):
        """
        Queries the vector database based on the given input query.
        Gets relevant doc based on the query and then passes it to an
        LLM as context to get the answer.

        Maintains the whole conversation in memory.

        :param input_query: The query to use.
        :type input_query: str
        :param contexts: Embeddings retrieved from the database to be used as context.
        :type contexts: list[str]
        :param config: The `BaseLlmConfig` instance to use as configuration options. This is used for one method call.
            To persistently use a config, declare it during app init., defaults to None
        :type config: Optional[BaseLlmConfig], optional
        :param dry_run: A dry run does everything except send the resulting prompt to
            the LLM. The purpose is to test the prompt, not the response., defaults to False
        :type dry_run: bool, optional
        :param session_id: Session ID to use for the conversation, defaults to None
        :type session_id: str, optional
        :return: The answer to the query or the dry run result
        :rtype: str
        """
        try:
            if config:
                # A config instance passed to this method will only be applied temporarily, for one call.
                # So we will save the previous config and restore it at the end of the execution.
                # For this we use the serializer.
                prev_config = self.config.serialize()
                self.config = config

            if self.is_docs_site_instance:
                self.config.prompt = DOCS_SITE_PROMPT_TEMPLATE
                self.config.number_documents = 5

            k = {}
            if self.config.online:
                k["web_search_result"] = self.access_search_and_get_results(input_query)

            prompt = self.generate_prompt(input_query, contexts, **k)
            logger.info(f"Prompt: {prompt}")
            if dry_run:
                return prompt

            answer, token_info = self.get_answer_from_llm(prompt)
            if isinstance(answer, str):
                logger.info(f"Answer: {answer}")
                return answer, token_info
            else:
                # This is a streamed response and needs to be handled differently.
                return self._stream_response(answer, token_info)
        finally:
            if config:
                # Restore previous config
                self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config)
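
    # Chat sketch (illustrative; `MyLlm` and the session id are placeholder names, and it
    # is assumed the subclass's `get_answer_from_llm` returns an (answer, token_info)
    # pair, as the unpacking above expects). A dry run skips the LLM call entirely.
    #
    #     llm = MyLlm()
    #     answer, token_info = llm.chat("Summarise the docs", ["retrieved chunk"],
    #                                   session_id="support-session")
    #     prompt_preview = llm.chat("Summarise the docs", ["retrieved chunk"], dry_run=True)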

    @staticmethod
    def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> list[LCBaseMessage]:
        """
        Construct a list of langchain messages

        :param prompt: User prompt
        :type prompt: str
        :param system_prompt: System prompt, defaults to None
        :type system_prompt: Optional[str], optional
        :return: List of messages
        :rtype: list[BaseMessage]
        """
        from langchain.schema import HumanMessage, SystemMessage

        messages = []
        if system_prompt:
            messages.append(SystemMessage(content=system_prompt))
        messages.append(HumanMessage(content=prompt))
        return messages
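
    # Message-construction sketch (static helper, callable without an instance; the
    # prompt strings are placeholders):
    #
    #     messages = BaseLlm._get_messages("Hello!", system_prompt="You are a helpful assistant.")
    #     # -> [SystemMessage(content="You are a helpful assistant."), HumanMessage(content="Hello!")]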