base.py

import logging
from typing import Any, Dict, Generator, List, Optional

from langchain.schema import BaseMessage as LCBaseMessage

from embedchain.config import BaseLlmConfig
from embedchain.config.llm.base import (
    DEFAULT_PROMPT,
    DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
    DOCS_SITE_PROMPT_TEMPLATE,
)
from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.memory.base import ChatHistory
from embedchain.memory.message import ChatMessage


class BaseLlm(JSONSerializable):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        """Initialize a base LLM class

        :param config: LLM configuration option class, defaults to None
        :type config: Optional[BaseLlmConfig], optional
        """
        if config is None:
            self.config = BaseLlmConfig()
        else:
            self.config = config

        self.memory = ChatHistory()
        self.is_docs_site_instance = False
        self.online = False
        self.history: Any = None

    def get_llm_model_answer(self):
        """
        Usually implemented by child class
        """
        raise NotImplementedError
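
    # A minimal sketch of the subclassing contract, assuming a concrete LLM
    # only needs to map a prompt string to an answer string (the `EchoLlm`
    # name below is hypothetical, not part of embedchain):
    #
    #     class EchoLlm(BaseLlm):
    #         def get_llm_model_answer(self, prompt: str) -> str:
    #             # A real subclass would call its provider's API here.
    #             return f"echo: {prompt}"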

    def set_history(self, history: Any):
        """
        Provide your own history.
        Especially interesting for the query method, which does not internally manage conversation history.

        :param history: History to set
        :type history: Any
        """
        self.history = history

    def update_history(self, app_id: str, session_id: str = "default"):
        """Update class history attribute with history in memory (for chat method)"""
        chat_history = self.memory.get(app_id=app_id, session_id=session_id, num_rounds=10)
        self.set_history([str(history) for history in chat_history])

    def add_history(
        self,
        app_id: str,
        question: str,
        answer: str,
        metadata: Optional[Dict[str, Any]] = None,
        session_id: str = "default",
    ):
        """Add one round of conversation (user question and AI answer) to memory and refresh `self.history`."""
        chat_message = ChatMessage()
        chat_message.add_user_message(question, metadata=metadata)
        chat_message.add_ai_message(answer, metadata=metadata)
        self.memory.add(app_id=app_id, chat_message=chat_message, session_id=session_id)
        self.update_history(app_id=app_id, session_id=session_id)
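
    # Sketch of the history round-trip, assuming an already-constructed
    # subclass instance `llm` (all values are illustrative):
    #
    #     llm.add_history(app_id="my-app", question="Hi?", answer="Hello!")
    #     # add_history stores the round in ChatHistory, then calls
    #     # update_history, so llm.history now holds stringified rounds.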

    def generate_prompt(self, input_query: str, contexts: List[str], **kwargs: Dict[str, Any]) -> str:
        """
        Generates a prompt based on the given query and context, ready to be
        passed to an LLM.

        :param input_query: The query to use.
        :type input_query: str
        :param contexts: List of similar documents to the query used as context.
        :type contexts: List[str]
        :return: The prompt
        :rtype: str
        """
        context_string = " | ".join(contexts)
        web_search_result = kwargs.get("web_search_result", "")
        if web_search_result:
            context_string = self._append_search_and_context(context_string, web_search_result)

        prompt_contains_history = self.config._validate_prompt_history(self.config.prompt)
        if prompt_contains_history:
            # Prompt contains history.
            # If there is no history yet, we insert `- no history -`.
            prompt = self.config.prompt.substitute(
                context=context_string, query=input_query, history=self.history or "- no history -"
            )
        elif self.history:
            # History is present, but the prompt does not include it
            # (prompt_contains_history is False here).
            if self.config.prompt.template == DEFAULT_PROMPT:
                # It's the default prompt, so swap in the history-aware default template.
                prompt = DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE.substitute(
                    context=context_string, query=input_query, history=self.history
                )
            else:
                # If we can't swap in the default, we still proceed, but tell users that the history is ignored.
                logging.warning(
                    "Your bot contains a history, but prompt does not include `$history` key. History is ignored."
                )
                prompt = self.config.prompt.substitute(context=context_string, query=input_query)
        else:
            # Basic use case: no history.
            prompt = self.config.prompt.substitute(context=context_string, query=input_query)
        return prompt
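
    # A minimal sketch of the template mechanics generate_prompt relies on:
    # config.prompt is a string.Template whose placeholders are substituted
    # here (template text and values are illustrative):
    #
    #     from string import Template
    #
    #     template = Template("Context: $context\nQuery: $query")
    #     prompt = template.substitute(context="doc1 | doc2", query="What is X?")
    #     # -> "Context: doc1 | doc2\nQuery: What is X?"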

    @staticmethod
    def _append_search_and_context(context: str, web_search_result: str) -> str:
        """Append web search context to existing context

        :param context: Existing context
        :type context: str
        :param web_search_result: Web search result
        :type web_search_result: str
        :return: Concatenated web search result
        :rtype: str
        """
        return f"{context}\nWeb Search Result: {web_search_result}"

    def get_answer_from_llm(self, prompt: str):
        """
        Gets an answer based on the given prompt by passing it
        to an LLM.

        :param prompt: The prompt to pass to the LLM.
        :type prompt: str
        :return: The answer.
        :rtype: Any
        """
        return self.get_llm_model_answer(prompt)

    @staticmethod
    def access_search_and_get_results(input_query: str):
        """
        Search the internet for additional context

        :param input_query: search query
        :type input_query: str
        :return: Search results
        :rtype: str
        """
        try:
            from langchain.tools import DuckDuckGoSearchRun
        except ImportError:
            raise ImportError(
                'Searching requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`'
            ) from None
        search = DuckDuckGoSearchRun()
        logging.info(f"Access search to get answers for {input_query}")
        return search.run(input_query)
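
    # Note: query/chat call this helper when `self.online` is True and feed the
    # result into generate_prompt as `web_search_result` (see below).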

    @staticmethod
    def _stream_response(answer: Any) -> Generator[Any, Any, None]:
        """Generator to be used as streaming response

        :param answer: Answer chunk from llm
        :type answer: Any
        :yield: Answer chunk from llm
        :rtype: Generator[Any, Any, None]
        """
        streamed_answer = ""
        for chunk in answer:
            streamed_answer = streamed_answer + chunk
            yield chunk
        logging.info(f"Answer: {streamed_answer}")
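
    # Sketch of consuming a streamed response, assuming a subclass whose
    # get_llm_model_answer returns an iterable of string chunks when
    # streaming is enabled (in that case query/chat return this generator):
    #
    #     for chunk in llm.query("question", contexts=["ctx"]):
    #         print(chunk, end="", flush=True)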

    def query(self, input_query: str, contexts: List[str], config: Optional[BaseLlmConfig] = None, dry_run=False):
        """
        Queries the vector database based on the given input query.
        Gets relevant doc based on the query and then passes it to an
        LLM as context to get the answer.

        :param input_query: The query to use.
        :type input_query: str
        :param contexts: Embeddings retrieved from the database to be used as context.
        :type contexts: List[str]
        :param config: The `BaseLlmConfig` instance to use as configuration options. This is used for one method call.
            To persistently use a config, declare it during app init., defaults to None
        :type config: Optional[BaseLlmConfig], optional
        :param dry_run: A dry run does everything except send the resulting prompt to
            the LLM. The purpose is to test the prompt, not the response., defaults to False
        :type dry_run: bool, optional
        :return: The answer to the query or the dry run result
        :rtype: str
        """
        try:
            if config:
                # A config instance passed to this method will only be applied temporarily, for one call.
                # So we save the previous config and restore it at the end of the execution.
                # For this we use the serializer.
                prev_config = self.config.serialize()
                self.config = config

            if config is not None and config.query_type == "Images":
                return contexts

            if self.is_docs_site_instance:
                self.config.prompt = DOCS_SITE_PROMPT_TEMPLATE
                self.config.number_documents = 5

            k = {}
            if self.online:
                k["web_search_result"] = self.access_search_and_get_results(input_query)

            prompt = self.generate_prompt(input_query, contexts, **k)
            logging.info(f"Prompt: {prompt}")
            if dry_run:
                return prompt

            answer = self.get_answer_from_llm(prompt)
            if isinstance(answer, str):
                logging.info(f"Answer: {answer}")
                return answer
            else:
                # This is a streamed response and needs to be handled differently.
                return self._stream_response(answer)
        finally:
            if config:
                # Restore previous config
                self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config)
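
    # Hedged usage sketch, assuming `llm` is a concrete subclass instance and
    # the contexts were already retrieved from the vector database:
    #
    #     prompt = llm.query("What is X?", contexts=["doc about X"], dry_run=True)
    #     # dry_run returns the fully rendered prompt without calling the LLM,
    #     # which is useful for testing prompt templates.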

    def chat(
        self,
        input_query: str,
        contexts: List[str],
        config: Optional[BaseLlmConfig] = None,
        dry_run=False,
        session_id: Optional[str] = None,
    ):
        """
        Queries the vector database on the given input query.
        Gets relevant doc based on the query and then passes it to an
        LLM as context to get the answer.

        Maintains the whole conversation in memory.

        :param input_query: The query to use.
        :type input_query: str
        :param contexts: Embeddings retrieved from the database to be used as context.
        :type contexts: List[str]
        :param config: The `BaseLlmConfig` instance to use as configuration options. This is used for one method call.
            To persistently use a config, declare it during app init., defaults to None
        :type config: Optional[BaseLlmConfig], optional
        :param dry_run: A dry run does everything except send the resulting prompt to
            the LLM. The purpose is to test the prompt, not the response., defaults to False
        :type dry_run: bool, optional
        :param session_id: Session ID to use for the conversation, defaults to None
        :type session_id: Optional[str], optional
        :return: The answer to the query or the dry run result
        :rtype: str
        """
        try:
            if config:
                # A config instance passed to this method will only be applied temporarily, for one call.
                # So we save the previous config and restore it at the end of the execution.
                # For this we use the serializer.
                prev_config = self.config.serialize()
                self.config = config

            if self.is_docs_site_instance:
                self.config.prompt = DOCS_SITE_PROMPT_TEMPLATE
                self.config.number_documents = 5

            k = {}
            if self.online:
                k["web_search_result"] = self.access_search_and_get_results(input_query)

            prompt = self.generate_prompt(input_query, contexts, **k)
            logging.info(f"Prompt: {prompt}")
            if dry_run:
                return prompt

            answer = self.get_answer_from_llm(prompt)
            if isinstance(answer, str):
                logging.info(f"Answer: {answer}")
                return answer
            else:
                # This is a streamed response and needs to be handled differently.
                return self._stream_response(answer)
        finally:
            if config:
                # Restore previous config
                self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config)
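
    # Hedged usage sketch: chat itself does not write to memory, so the caller
    # is expected to record each round via add_history so update_history can
    # feed it back into later prompts (names and values are illustrative):
    #
    #     answer = llm.chat("And Y?", contexts=["doc about Y"], session_id="s1")
    #     llm.add_history(app_id="my-app", question="And Y?", answer=answer, session_id="s1")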

    @staticmethod
    def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[LCBaseMessage]:
        """
        Construct a list of langchain messages

        :param prompt: User prompt
        :type prompt: str
        :param system_prompt: System prompt, defaults to None
        :type system_prompt: Optional[str], optional
        :return: List of messages
        :rtype: List[LCBaseMessage]
        """
        from langchain.schema import HumanMessage, SystemMessage

        messages = []
        if system_prompt:
            messages.append(SystemMessage(content=system_prompt))
        messages.append(HumanMessage(content=prompt))
        return messages
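
    # Sketch of what _get_messages produces (values are illustrative):
    #
    #     BaseLlm._get_messages("Hello", system_prompt="Be brief")
    #     # -> [SystemMessage(content="Be brief"), HumanMessage(content="Hello")]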