base.py

import logging
from typing import Any, Dict, Generator, List, Optional

from langchain.schema import BaseMessage as LCBaseMessage

from embedchain.config import BaseLlmConfig
from embedchain.config.llm.base import (DEFAULT_PROMPT,
                                        DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
                                        DOCS_SITE_PROMPT_TEMPLATE)
from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.memory.base import ECChatMemory
from embedchain.memory.message import ChatMessage


class BaseLlm(JSONSerializable):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        """Initialize a base LLM class

        :param config: LLM configuration option class, defaults to None
        :type config: Optional[BaseLlmConfig], optional
        """
        if config is None:
            self.config = BaseLlmConfig()
        else:
            self.config = config

        self.memory = ECChatMemory()
        self.is_docs_site_instance = False
        self.online = False
        self.history: Any = None

    def get_llm_model_answer(self):
        """
        Usually implemented by child class
        """
        raise NotImplementedError
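
    # A concrete provider subclass is expected to override `get_llm_model_answer`
    # and accept the prompt built by this class. Hedged sketch (the class name
    # and the provider call are illustrative, not part of this module):
    #
    #     class MyProviderLlm(BaseLlm):
    #         def get_llm_model_answer(self, prompt: str) -> str:
    #             return call_provider_api(prompt)  # hypothetical helper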

    def set_history(self, history: Any):
        """
        Provide your own history.
        Especially interesting for the query method, which does not internally manage conversation history.

        :param history: History to set
        :type history: Any
        """
        self.history = history

    def update_history(self, app_id: str):
        """Update class history attribute with history in memory (for chat method)"""
        chat_history = self.memory.get_recent_memories(app_id=app_id, num_rounds=10)
        self.set_history([str(history) for history in chat_history])

    def add_history(self, app_id: str, question: str, answer: str, metadata: Optional[Dict[str, Any]] = None):
        """Store one question/answer round in memory and refresh the in-class history."""
        chat_message = ChatMessage()
        chat_message.add_user_message(question, metadata=metadata)
        chat_message.add_ai_message(answer, metadata=metadata)
        self.memory.add(app_id=app_id, chat_message=chat_message)
        self.update_history(app_id=app_id)
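
    # History round trip, roughly: `add_history` writes one question/answer round
    # to ECChatMemory, then `update_history` reads back the last 10 rounds and
    # stores their string form on `self.history`, where `generate_prompt` can use
    # them. Hedged sketch (app id and printed form are illustrative):
    #
    #     llm.add_history(app_id="my-app", question="Hi", answer="Hello!")
    #     print(llm.history)  # e.g. ["human: Hi\nai: Hello!"]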

    def generate_prompt(self, input_query: str, contexts: List[str], **kwargs: Dict[str, Any]) -> str:
        """
        Generates a prompt based on the given query and context, ready to be
        passed to an LLM

        :param input_query: The query to use.
        :type input_query: str
        :param contexts: List of similar documents to the query used as context.
        :type contexts: List[str]
        :return: The prompt
        :rtype: str
        """
        context_string = (" | ").join(contexts)
        web_search_result = kwargs.get("web_search_result", "")
        if web_search_result:
            context_string = self._append_search_and_context(context_string, web_search_result)

        prompt_contains_history = self.config._validate_prompt_history(self.config.prompt)
        if prompt_contains_history:
            # Prompt contains history
            # If there is no history yet, we insert `- no history -`
            prompt = self.config.prompt.substitute(
                context=context_string, query=input_query, history=self.history or "- no history -"
            )
        elif self.history and not prompt_contains_history:
            # History is present, but not included in the prompt.
            # check if it's the default prompt without history
            if (
                not self.config._validate_prompt_history(self.config.prompt)
                and self.config.prompt.template == DEFAULT_PROMPT
            ):
                # swap in the template with history
                prompt = DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE.substitute(
                    context=context_string, query=input_query, history=self.history
                )
            else:
                # If we can't swap in the default, we still proceed but tell users that the history is ignored.
                logging.warning(
                    "Your bot contains a history, but prompt does not include `$history` key. History is ignored."
                )
                prompt = self.config.prompt.substitute(context=context_string, query=input_query)
        else:
            # basic use case, no history.
            prompt = self.config.prompt.substitute(context=context_string, query=input_query)

        return prompt
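
    # The prompt object behaves like `string.Template`: `$context`, `$query` and
    # optionally `$history` are substituted. Hedged sketch of the history-free
    # branch (the template text here is shortened for illustration and is not
    # the real DEFAULT_PROMPT):
    #
    #     from string import Template
    #     Template("Context: $context\nQuery: $query").substitute(
    #         context="doc A | doc B", query="What is X?"
    #     )
    #     # -> "Context: doc A | doc B\nQuery: What is X?"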

    def _append_search_and_context(self, context: str, web_search_result: str) -> str:
        """Append web search context to existing context

        :param context: Existing context
        :type context: str
        :param web_search_result: Web search result
        :type web_search_result: str
        :return: Concatenated web search result
        :rtype: str
        """
        return f"{context}\nWeb Search Result: {web_search_result}"

    def get_answer_from_llm(self, prompt: str):
        """
        Gets an answer based on the given query and context by passing it
        to an LLM.

        :param prompt: The prompt to send to the LLM.
        :type prompt: str
        :return: The answer, or a stream of answer chunks when streaming is enabled.
        :rtype: Any
        """
        return self.get_llm_model_answer(prompt)

    def access_search_and_get_results(self, input_query: str):
        """
        Search the internet for additional context

        :param input_query: search query
        :type input_query: str
        :return: Search results
        :rtype: str
        """
        try:
            from langchain.tools import DuckDuckGoSearchRun
        except ImportError:
            raise ImportError(
                'Searching requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`'
            ) from None
        search = DuckDuckGoSearchRun()
        logging.info(f"Access search to get answers for {input_query}")
        return search.run(input_query)
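
    # Web search is opt-in: setting `self.online = True` makes `query`/`chat`
    # call this method and pass the result to `generate_prompt` as
    # `web_search_result`. Hedged sketch:
    #
    #     llm.online = True
    #     answer = llm.query("What changed in the latest release?", contexts=[])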

    def _stream_response(self, answer: Any) -> Generator[Any, Any, None]:
        """Generator to be used as streaming response

        :param answer: Answer chunk from llm
        :type answer: Any
        :yield: Answer chunk from llm
        :rtype: Generator[Any, Any, None]
        """
        streamed_answer = ""
        for chunk in answer:
            streamed_answer = streamed_answer + chunk
            yield chunk
        logging.info(f"Answer: {streamed_answer}")
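
    # When the underlying model streams, `query`/`chat` return this generator
    # instead of a string, and callers iterate it to receive chunks. Hedged
    # sketch of consuming either form:
    #
    #     response = llm.query("...", contexts=["..."])
    #     if isinstance(response, str):
    #         print(response)
    #     else:
    #         for chunk in response:
    #             print(chunk, end="", flush=True)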

    def query(self, input_query: str, contexts: List[str], config: BaseLlmConfig = None, dry_run=False):
        """
        Queries the vector database based on the given input query.
        Gets relevant doc based on the query and then passes it to an
        LLM as context to get the answer.

        :param input_query: The query to use.
        :type input_query: str
        :param contexts: Embeddings retrieved from the database to be used as context.
        :type contexts: List[str]
        :param config: The `BaseLlmConfig` instance to use as configuration options. This is used for one method call.
            To persistently use a config, declare it during app init., defaults to None
        :type config: Optional[BaseLlmConfig], optional
        :param dry_run: A dry run does everything except send the resulting prompt to
            the LLM. The purpose is to test the prompt, not the response., defaults to False
        :type dry_run: bool, optional
        :return: The answer to the query or the dry run result
        :rtype: str
        """
        try:
            if config:
                # A config instance passed to this method will only be applied temporarily, for one call.
                # So we will save the previous config and restore it at the end of the execution.
                # For this we use the serializer.
                prev_config = self.config.serialize()
                self.config = config

            if config is not None and config.query_type == "Images":
                return contexts

            if self.is_docs_site_instance:
                self.config.prompt = DOCS_SITE_PROMPT_TEMPLATE
                self.config.number_documents = 5

            k = {}
            if self.online:
                k["web_search_result"] = self.access_search_and_get_results(input_query)

            prompt = self.generate_prompt(input_query, contexts, **k)
            logging.info(f"Prompt: {prompt}")
            if dry_run:
                return prompt

            answer = self.get_answer_from_llm(prompt)
            if isinstance(answer, str):
                logging.info(f"Answer: {answer}")
                return answer
            else:
                return self._stream_response(answer)
        finally:
            if config:
                # Restore previous config
                self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config)
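
    # A config passed to `query`/`chat` applies to that single call only: the
    # current config is serialized first and restored in the `finally` block.
    # Hedged sketch (the constructor argument shown is illustrative):
    #
    #     one_off = BaseLlmConfig(temperature=0.1)
    #     llm.query("...", contexts=["..."], config=one_off)
    #     # llm.config has been restored to its previous value here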

    def chat(self, input_query: str, contexts: List[str], config: BaseLlmConfig = None, dry_run=False):
        """
        Queries the vector database on the given input query.
        Gets relevant doc based on the query and then passes it to an
        LLM as context to get the answer.

        Maintains the whole conversation in memory.

        :param input_query: The query to use.
        :type input_query: str
        :param contexts: Embeddings retrieved from the database to be used as context.
        :type contexts: List[str]
        :param config: The `BaseLlmConfig` instance to use as configuration options. This is used for one method call.
            To persistently use a config, declare it during app init., defaults to None
        :type config: Optional[BaseLlmConfig], optional
        :param dry_run: A dry run does everything except send the resulting prompt to
            the LLM. The purpose is to test the prompt, not the response., defaults to False
        :type dry_run: bool, optional
        :return: The answer to the query or the dry run result
        :rtype: str
        """
        try:
            if config:
                # A config instance passed to this method will only be applied temporarily, for one call.
                # So we will save the previous config and restore it at the end of the execution.
                # For this we use the serializer.
                prev_config = self.config.serialize()
                self.config = config

            if self.is_docs_site_instance:
                self.config.prompt = DOCS_SITE_PROMPT_TEMPLATE
                self.config.number_documents = 5

            k = {}
            if self.online:
                k["web_search_result"] = self.access_search_and_get_results(input_query)

            prompt = self.generate_prompt(input_query, contexts, **k)
            logging.info(f"Prompt: {prompt}")
            if dry_run:
                return prompt

            answer = self.get_answer_from_llm(prompt)
            if isinstance(answer, str):
                logging.info(f"Answer: {answer}")
                return answer
            else:
                # this is a streamed response and needs to be handled differently.
                return self._stream_response(answer)
        finally:
            if config:
                # Restore previous config
                self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config)

    @staticmethod
    def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[LCBaseMessage]:
        """
        Construct a list of langchain messages

        :param prompt: User prompt
        :type prompt: str
        :param system_prompt: System prompt, defaults to None
        :type system_prompt: Optional[str], optional
        :return: List of messages
        :rtype: List[BaseMessage]
        """
        from langchain.schema import HumanMessage, SystemMessage

        messages = []
        if system_prompt:
            messages.append(SystemMessage(content=system_prompt))
        messages.append(HumanMessage(content=prompt))
        return messages
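

# --- Usage sketch (not part of the original module) ---------------------------
# A minimal, hypothetical subclass to illustrate the contract: a provider only
# implements `get_llm_model_answer`; prompt construction, history handling and
# per-call config management are inherited from BaseLlm.
if __name__ == "__main__":
    class EchoLlm(BaseLlm):
        def get_llm_model_answer(self, prompt: str) -> str:
            # A real provider would call its API here; we simply echo the prompt.
            return f"[echo] {prompt}"

    llm = EchoLlm()
    # dry_run=True returns the generated prompt without calling the model.
    print(llm.query("What does BaseLlm do?", contexts=["BaseLlm builds prompts."], dry_run=True))
    print(llm.query("What does BaseLlm do?", contexts=["BaseLlm builds prompts."]))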