base_llm.py 8.4 KB

import logging
from typing import Any, List, Optional

from langchain.memory import ConversationBufferMemory
from langchain.schema import BaseMessage

from embedchain.helper_classes.json_serializable import JSONSerializable
from embedchain.config import BaseLlmConfig
from embedchain.config.llm.base_llm_config import (
    DEFAULT_PROMPT,
    DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
    DOCS_SITE_PROMPT_TEMPLATE,
)


class BaseLlm(JSONSerializable):
    def __init__(self, config: Optional[BaseLlmConfig] = None):
        if config is None:
            self.config = BaseLlmConfig()
        else:
            self.config = config

        self.memory = ConversationBufferMemory()
        self.is_docs_site_instance = False
        self.online = False
        self.history: Any = None

    def get_llm_model_answer(self, prompt):
        """
        Usually implemented by child class
        """
        raise NotImplementedError

    def set_history(self, history: Any):
        self.history = history

    def update_history(self):
        chat_history = self.memory.load_memory_variables({})["history"]
        if chat_history:
            self.set_history(chat_history)

    def generate_prompt(self, input_query, contexts, **kwargs):
        """
        Generates a prompt based on the given query and context, ready to be
        passed to an LLM.

        :param input_query: The query to use.
        :param contexts: List of similar documents to the query used as context.
        :param web_search_result: Optional keyword argument. Web search results
        to append to the context.
        :return: The prompt
        """
        context_string = (" | ").join(contexts)
        web_search_result = kwargs.get("web_search_result", "")
        if web_search_result:
            context_string = self._append_search_and_context(context_string, web_search_result)
        if not self.history:
            prompt = self.config.template.substitute(context=context_string, query=input_query)
        else:
            # check if it's the default template without history
            if (
                not self.config._validate_template_history(self.config.template)
                and self.config.template.template == DEFAULT_PROMPT
            ):
                # swap in the template with history
                prompt = DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE.substitute(
                    context=context_string, query=input_query, history=self.history
                )
            elif not self.config._validate_template_history(self.config.template):
                logging.warning("Template does not include `$history` key. History is not included in prompt.")
                prompt = self.config.template.substitute(context=context_string, query=input_query)
            else:
                prompt = self.config.template.substitute(
                    context=context_string, query=input_query, history=self.history
                )
        return prompt

    def _append_search_and_context(self, context, web_search_result):
        return f"{context}\nWeb Search Result: {web_search_result}"

    def get_answer_from_llm(self, prompt):
        """
        Gets an answer for the given prompt by passing it to an LLM.

        :param prompt: The prompt to send to the LLM.
        :return: The answer.
        """
        return self.get_llm_model_answer(prompt)

    def access_search_and_get_results(self, input_query):
        from langchain.tools import DuckDuckGoSearchRun

        search = DuckDuckGoSearchRun()
        logging.info(f"Access search to get answers for {input_query}")
        return search.run(input_query)

    def _stream_query_response(self, answer):
        streamed_answer = ""
        for chunk in answer:
            streamed_answer = streamed_answer + chunk
            yield chunk
        logging.info(f"Answer: {streamed_answer}")

    def _stream_chat_response(self, answer):
        streamed_answer = ""
        for chunk in answer:
            streamed_answer = streamed_answer + chunk
            yield chunk
        self.memory.chat_memory.add_ai_message(streamed_answer)
        logging.info(f"Answer: {streamed_answer}")

    def query(self, input_query, contexts, config: BaseLlmConfig = None, dry_run=False, where=None):
        """
        Queries the vector database based on the given input query.
        Gets relevant doc based on the query and then passes it to an
        LLM as context to get the answer.

        :param input_query: The query to use.
        :param contexts: List of similar documents to the query used as context.
        :param config: Optional. The `LlmConfig` instance to use as configuration options.
        This is used for one method call. To persistently use a config, declare it during app init.
        :param dry_run: Optional. A dry run does everything except send the resulting prompt to
        the LLM. The purpose is to test the prompt, not the response.
        You can use it to test your prompt, including the context provided
        by the vector database's doc retrieval.
        The only thing the dry run does not consider is the cut-off due to
        the `max_tokens` parameter.
        :param where: Optional. A dictionary of key-value pairs to filter the database results.
        :return: The answer to the query.
        """
        query_config = config or self.config
        if self.is_docs_site_instance:
            query_config.template = DOCS_SITE_PROMPT_TEMPLATE
            query_config.number_documents = 5
        k = {}
        if self.online:
            k["web_search_result"] = self.access_search_and_get_results(input_query)
        prompt = self.generate_prompt(input_query, contexts, **k)
        logging.info(f"Prompt: {prompt}")
        if dry_run:
            return prompt
        answer = self.get_answer_from_llm(prompt)
        if isinstance(answer, str):
            logging.info(f"Answer: {answer}")
            return answer
        else:
            return self._stream_query_response(answer)

    def chat(self, input_query, contexts, config: BaseLlmConfig = None, dry_run=False, where=None):
        """
        Queries the vector database on the given input query.
        Gets relevant doc based on the query and then passes it to an
        LLM as context to get the answer.

        Maintains the whole conversation in memory.

        :param input_query: The query to use.
        :param contexts: List of similar documents to the query used as context.
        :param config: Optional. The `LlmConfig` instance to use as configuration options.
        This is used for one method call. To persistently use a config, declare it during app init.
        :param dry_run: Optional. A dry run does everything except send the resulting prompt to
        the LLM. The purpose is to test the prompt, not the response.
        You can use it to test your prompt, including the context provided
        by the vector database's doc retrieval.
        The only thing the dry run does not consider is the cut-off due to
        the `max_tokens` parameter.
        :param where: Optional. A dictionary of key-value pairs to filter the database results.
        :return: The answer to the query.
        """
        query_config = config or self.config
        if self.is_docs_site_instance:
            query_config.template = DOCS_SITE_PROMPT_TEMPLATE
            query_config.number_documents = 5
        k = {}
        if self.online:
            k["web_search_result"] = self.access_search_and_get_results(input_query)
        self.update_history()
        prompt = self.generate_prompt(input_query, contexts, **k)
        logging.info(f"Prompt: {prompt}")
        if dry_run:
            return prompt
        answer = self.get_answer_from_llm(prompt)
        self.memory.chat_memory.add_user_message(input_query)
        if isinstance(answer, str):
            self.memory.chat_memory.add_ai_message(answer)
            logging.info(f"Answer: {answer}")
            # NOTE: Adding to history before and after. This could be seen as redundant.
            # If we change it, we have to change the tests (no big deal).
            self.update_history()
            return answer
        else:
            # this is a streamed response and needs to be handled differently.
            return self._stream_chat_response(answer)

    @staticmethod
    def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[BaseMessage]:
        from langchain.schema import HumanMessage, SystemMessage

        messages = []
        if system_prompt:
            messages.append(SystemMessage(content=system_prompt))
        messages.append(HumanMessage(content=prompt))
        return messages
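

In practice this class is abstract: concrete LLM integrations subclass BaseLlm and implement get_llm_model_answer, while query() and chat() handle prompt construction, history, optional web search context, and streaming. The following is a minimal usage sketch, not part of the file above: EchoLlm is a hypothetical stand-in for a real provider class, the import path is assumed from this file's location, and it assumes the default BaseLlmConfig supplies the default prompt template.

# Hypothetical usage sketch for the BaseLlm class defined above.
from embedchain.llm.base_llm import BaseLlm  # assumed module path for this file


class EchoLlm(BaseLlm):
    """Toy subclass that echoes the prompt back; stands in for a real model."""

    def get_llm_model_answer(self, prompt):
        # A real subclass would call its provider's API here; it may return
        # either a full string or an iterator of chunks, which BaseLlm then
        # streams via _stream_query_response / _stream_chat_response.
        return f"[echo] {prompt}"


llm = EchoLlm()  # falls back to a default BaseLlmConfig()

# Contexts would normally come from the vector database retrieval step.
contexts = ["Paris is the capital of France.", "France is in Europe."]

# dry_run=True returns the fully rendered prompt without calling the model,
# which is useful for inspecting the template substitution described above.
print(llm.query("What is the capital of France?", contexts, dry_run=True))

# A normal call sends the rendered prompt to get_llm_model_answer.
print(llm.query("What is the capital of France?", contexts))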