base.py

import logging
from typing import Any, Dict, Generator, List, Optional

from langchain.memory import ConversationBufferMemory
from langchain.schema import BaseMessage

from embedchain.config import BaseLlmConfig
from embedchain.config.llm.base_llm_config import (
    DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
    DOCS_SITE_PROMPT_TEMPLATE)
from embedchain.helper.json_serializable import JSONSerializable

class BaseLlm(JSONSerializable):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        """Initialize a base LLM class

        :param config: LLM configuration option class, defaults to None
        :type config: Optional[BaseLlmConfig], optional
        """
        if config is None:
            self.config = BaseLlmConfig()
        else:
            self.config = config

        self.memory = ConversationBufferMemory()
        self.is_docs_site_instance = False
        self.online = False
        self.history: Any = None

    def get_llm_model_answer(self):
        """
        Usually implemented by child class
        """
        raise NotImplementedError
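    # NOTE (editor's addition, not in the original source): concrete providers are
    # expected to override get_llm_model_answer(prompt) and return either a plain
    # string or an iterable of chunks when streaming is enabled; a hedged example
    # subclass is sketched at the end of this module.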
    def set_history(self, history: Any):
        """
        Provide your own history.
        Especially interesting for the query method, which does not internally manage conversation history.

        :param history: History to set
        :type history: Any
        """
        self.history = history

    def update_history(self):
        """Update class history attribute with history in memory (for chat method)"""
        chat_history = self.memory.load_memory_variables({})["history"]
        if chat_history:
            self.set_history(chat_history)

    def generate_prompt(self, input_query: str, contexts: List[str], **kwargs: Dict[str, Any]) -> str:
        """
        Generates a prompt based on the given query and context, ready to be
        passed to an LLM.

        :param input_query: The query to use.
        :type input_query: str
        :param contexts: List of similar documents to the query used as context.
        :type contexts: List[str]
        :return: The prompt
        :rtype: str
        """
        context_string = " | ".join(contexts)
        web_search_result = kwargs.get("web_search_result", "")
        if web_search_result:
            context_string = self._append_search_and_context(context_string, web_search_result)
        if not self.history:
            prompt = self.config.template.substitute(context=context_string, query=input_query)
        else:
            # check if it's the default template without history
            if (
                not self.config._validate_template_history(self.config.template)
                and self.config.template.template == DEFAULT_PROMPT
            ):
                # swap in the template with history
                prompt = DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE.substitute(
                    context=context_string, query=input_query, history=self.history
                )
            elif not self.config._validate_template_history(self.config.template):
                logging.warning("Template does not include `$history` key. History is not included in prompt.")
                prompt = self.config.template.substitute(context=context_string, query=input_query)
            else:
                prompt = self.config.template.substitute(
                    context=context_string, query=input_query, history=self.history
                )
        return prompt
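    # NOTE (editor's addition, not in the original source): self.config.template is
    # assumed to be a string.Template, which is why generate_prompt() calls
    # .substitute(...) and why history handling checks for a $history placeholder.
    # Roughly:
    #
    #     from string import Template
    #     t = Template("Context: $context\nQuery: $query")
    #     t.substitute(context="doc1 | doc2", query="What is BaseLlm?")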
    def _append_search_and_context(self, context: str, web_search_result: str) -> str:
        """Append web search context to existing context

        :param context: Existing context
        :type context: str
        :param web_search_result: Web search result
        :type web_search_result: str
        :return: Concatenated web search result
        :rtype: str
        """
        return f"{context}\nWeb Search Result: {web_search_result}"
    def get_answer_from_llm(self, prompt: str):
        """
        Gets an answer based on the given query and context by passing it
        to an LLM.

        :param prompt: The prompt to pass to the LLM.
        :type prompt: str
        :return: The answer.
        :rtype: Any
        """
        return self.get_llm_model_answer(prompt)
    def access_search_and_get_results(self, input_query: str):
        """
        Search the internet for additional context

        :param input_query: search query
        :type input_query: str
        :return: Search results
        :rtype: Unknown
        """
        from langchain.tools import DuckDuckGoSearchRun

        search = DuckDuckGoSearchRun()
        logging.info(f"Access search to get answers for {input_query}")
        return search.run(input_query)
    def _stream_query_response(self, answer: Any) -> Generator[Any, Any, None]:
        """Generator to be used as streaming response

        :param answer: Answer chunk from llm
        :type answer: Any
        :yield: Answer chunk from llm
        :rtype: Generator[Any, Any, None]
        """
        streamed_answer = ""
        for chunk in answer:
            streamed_answer = streamed_answer + chunk
            yield chunk
        logging.info(f"Answer: {streamed_answer}")

    def _stream_chat_response(self, answer: Any) -> Generator[Any, Any, None]:
        """Generator to be used as streaming response

        :param answer: Answer chunk from llm
        :type answer: Any
        :yield: Answer chunk from llm
        :rtype: Generator[Any, Any, None]
        """
        streamed_answer = ""
        for chunk in answer:
            streamed_answer = streamed_answer + chunk
            yield chunk
        self.memory.chat_memory.add_ai_message(streamed_answer)
        logging.info(f"Answer: {streamed_answer}")
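    # NOTE (editor's addition, not in the original source): both streaming helpers
    # yield chunks to the caller as they arrive while also accumulating them, so the
    # complete answer can still be logged and, in the chat case, written back into
    # conversation memory once the stream is exhausted.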
    def query(self, input_query: str, contexts: List[str], config: BaseLlmConfig = None, dry_run=False):
        """
        Queries the vector database based on the given input query.
        Gets relevant doc based on the query and then passes it to an
        LLM as context to get the answer.

        :param input_query: The query to use.
        :type input_query: str
        :param contexts: Context documents retrieved from the database, used as context.
        :type contexts: List[str]
        :param config: The `LlmConfig` instance to use as configuration options. This is used for one method call.
            To persistently use a config, declare it during app init., defaults to None
        :type config: Optional[BaseLlmConfig], optional
        :param dry_run: A dry run does everything except send the resulting prompt to
            the LLM. The purpose is to test the prompt, not the response., defaults to False
        :type dry_run: bool, optional
        :return: The answer to the query or the dry run result
        :rtype: str
        """
        try:
            if config:
                # A config instance passed to this method will only be applied temporarily, for one call.
                # So we will save the previous config and restore it at the end of the execution.
                # For this we use the serializer.
                prev_config = self.config.serialize()
                self.config = config

            if self.is_docs_site_instance:
                self.config.template = DOCS_SITE_PROMPT_TEMPLATE
                self.config.number_documents = 5
            k = {}
            if self.online:
                k["web_search_result"] = self.access_search_and_get_results(input_query)
            prompt = self.generate_prompt(input_query, contexts, **k)
            logging.info(f"Prompt: {prompt}")

            if dry_run:
                return prompt

            answer = self.get_answer_from_llm(prompt)

            if isinstance(answer, str):
                logging.info(f"Answer: {answer}")
                return answer
            else:
                return self._stream_query_response(answer)
        finally:
            if config:
                # Restore previous config
                self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config)
    def chat(self, input_query: str, contexts: List[str], config: BaseLlmConfig = None, dry_run=False):
        """
        Queries the vector database on the given input query.
        Gets relevant doc based on the query and then passes it to an
        LLM as context to get the answer.

        Maintains the whole conversation in memory.

        :param input_query: The query to use.
        :type input_query: str
        :param contexts: Context documents retrieved from the database, used as context.
        :type contexts: List[str]
        :param config: The `LlmConfig` instance to use as configuration options. This is used for one method call.
            To persistently use a config, declare it during app init., defaults to None
        :type config: Optional[BaseLlmConfig], optional
        :param dry_run: A dry run does everything except send the resulting prompt to
            the LLM. The purpose is to test the prompt, not the response., defaults to False
        :type dry_run: bool, optional
        :return: The answer to the query or the dry run result
        :rtype: str
        """
        try:
            if config:
                # A config instance passed to this method will only be applied temporarily, for one call.
                # So we will save the previous config and restore it at the end of the execution.
                # For this we use the serializer.
                prev_config = self.config.serialize()
                self.config = config

            if self.is_docs_site_instance:
                self.config.template = DOCS_SITE_PROMPT_TEMPLATE
                self.config.number_documents = 5
            k = {}
            if self.online:
                k["web_search_result"] = self.access_search_and_get_results(input_query)

            self.update_history()

            prompt = self.generate_prompt(input_query, contexts, **k)
            logging.info(f"Prompt: {prompt}")

            if dry_run:
                return prompt

            answer = self.get_answer_from_llm(prompt)

            self.memory.chat_memory.add_user_message(input_query)

            if isinstance(answer, str):
                self.memory.chat_memory.add_ai_message(answer)
                logging.info(f"Answer: {answer}")

                # NOTE: Adding to history before and after. This could be seen as redundant.
                # If we change it, we have to change the tests (no big deal).
                self.update_history()

                return answer
            else:
                # this is a streamed response and needs to be handled differently.
                return self._stream_chat_response(answer)
        finally:
            if config:
                # Restore previous config
                self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config)
    @staticmethod
    def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[BaseMessage]:
        """
        Construct a list of langchain messages

        :param prompt: User prompt
        :type prompt: str
        :param system_prompt: System prompt, defaults to None
        :type system_prompt: Optional[str], optional
        :return: List of messages
        :rtype: List[BaseMessage]
        """
        from langchain.schema import HumanMessage, SystemMessage

        messages = []
        if system_prompt:
            messages.append(SystemMessage(content=system_prompt))
        messages.append(HumanMessage(content=prompt))
        return messages
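
# ---------------------------------------------------------------------------
# Editor's sketch, not part of the original embedchain source. It shows, under
# the assumption that only the hooks visible above are needed, how a subclass
# might plug in a model and how query()/chat() are then driven with
# caller-supplied context strings. EchoLlm is a hypothetical stand-in for a
# real provider integration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class EchoLlm(BaseLlm):
        """Hypothetical provider used only for this sketch."""

        def get_llm_model_answer(self, prompt: str) -> str:
            # A real subclass would call its provider's API here, honouring
            # self.config (model name, temperature, streaming, ...).
            return f"(echo) {prompt[:60]}"

    llm = EchoLlm()

    # query(): contexts are supplied by the caller (e.g. from a vector-store
    # retrieval step); history is only used if set_history() was called.
    print(
        llm.query(
            input_query="What does BaseLlm do?",
            contexts=[
                "BaseLlm turns contexts and a query into a prompt.",
                "Subclasses implement get_llm_model_answer().",
            ],
        )
    )

    # chat(): same flow, but the exchange is stored in ConversationBufferMemory
    # and folded back into the prompt on later turns.
    print(llm.chat("How is chat different from query?", contexts=["chat() keeps history."]))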