import json
import logging
import re
from string import Template
from typing import Any, Dict, Mapping, Optional, Union

import httpx

from embedchain.config.base_config import BaseConfig
from embedchain.helpers.json_serializable import register_deserializable

logger = logging.getLogger(__name__)

DEFAULT_PROMPT = """
You are a Q&A expert system. Your responses must always be rooted in the context provided for each query. Here are some guidelines to follow:

1. Refrain from explicitly mentioning the context provided in your response.
2. The context should silently guide your answers without being directly acknowledged.
3. Do not use phrases such as 'According to the context provided', 'Based on the context, ...' etc.

Context information:
----------------------
$context
----------------------

Query: $query
Answer:
"""  # noqa:E501

DEFAULT_PROMPT_WITH_HISTORY = """
You are a Q&A expert system. Your responses must always be rooted in the context provided for each query. You are also provided with the conversation history with the user. Make sure to use relevant context from conversation history as needed.

Here are some guidelines to follow:

1. Refrain from explicitly mentioning the context provided in your response.
2. The context should silently guide your answers without being directly acknowledged.
3. Do not use phrases such as 'According to the context provided', 'Based on the context, ...' etc.

Context information:
----------------------
$context
----------------------

Conversation history:
----------------------
$history
----------------------

Query: $query
Answer:
"""  # noqa:E501

DOCS_SITE_DEFAULT_PROMPT = """
You are an expert AI assistant for a developer support product. Your responses must always be rooted in the context provided for each query. Wherever possible, give a complete code snippet. Don't make up any code snippets on your own.

Here are some guidelines to follow:

1. Refrain from explicitly mentioning the context provided in your response.
2. The context should silently guide your answers without being directly acknowledged.
3. Do not use phrases such as 'According to the context provided', 'Based on the context, ...' etc.

Context information:
----------------------
$context
----------------------

Query: $query
Answer:
"""  # noqa:E501

DEFAULT_PROMPT_TEMPLATE = Template(DEFAULT_PROMPT)
DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE = Template(DEFAULT_PROMPT_WITH_HISTORY)
DOCS_SITE_PROMPT_TEMPLATE = Template(DOCS_SITE_DEFAULT_PROMPT)
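
# Illustrative only (the values are placeholders, not library defaults):
# rendering a template into the final prompt string looks like
#     DEFAULT_PROMPT_TEMPLATE.substitute(context="<retrieved docs>", query="<user question>")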

query_re = re.compile(r"\$\{*query\}*")
context_re = re.compile(r"\$\{*context\}*")
history_re = re.compile(r"\$\{*history\}*")
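
# These patterns accept both `$var` and `${var}` placeholder spellings; e.g.
# `query_re` matches "Query: $query" as well as "Query: ${query}".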


@register_deserializable
class BaseLlmConfig(BaseConfig):
    """
    Config for the `query` method.
    """

    def __init__(
        self,
        number_documents: int = 3,
        template: Optional[Template] = None,
        prompt: Optional[Template] = None,
        model: Optional[str] = None,
        temperature: float = 0,
        max_tokens: int = 1000,
        top_p: float = 1,
        stream: bool = False,
        online: bool = False,
        token_usage: bool = False,
        deployment_name: Optional[str] = None,
        system_prompt: Optional[str] = None,
        where: Optional[dict[str, Any]] = None,
        query_type: Optional[str] = None,
        callbacks: Optional[list] = None,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        endpoint: Optional[str] = None,
        model_kwargs: Optional[dict[str, Any]] = None,
        http_client_proxies: Optional[Union[Dict, str]] = None,
        http_async_client_proxies: Optional[Union[Dict, str]] = None,
        local: Optional[bool] = False,
        default_headers: Optional[Mapping[str, str]] = None,
        api_version: Optional[str] = None,
    ):
  90. """
  91. Initializes a configuration class instance for the LLM.
  92. Takes the place of the former `QueryConfig` or `ChatConfig`.
  93. :param number_documents: Number of documents to pull from the database as
  94. context, defaults to 1
  95. :type number_documents: int, optional
  96. :param template: The `Template` instance to use as a template for
  97. prompt, defaults to None (deprecated)
  98. :type template: Optional[Template], optional
  99. :param prompt: The `Template` instance to use as a template for
  100. prompt, defaults to None
  101. :type prompt: Optional[Template], optional
  102. :param model: Controls the OpenAI model used, defaults to None
  103. :type model: Optional[str], optional
  104. :param temperature: Controls the randomness of the model's output.
  105. Higher values (closer to 1) make output more random, lower values make it more deterministic, defaults to 0
  106. :type temperature: float, optional
  107. :param max_tokens: Controls how many tokens are generated, defaults to 1000
  108. :type max_tokens: int, optional
  109. :param top_p: Controls the diversity of words. Higher values (closer to 1) make word selection more diverse,
  110. defaults to 1
  111. :type top_p: float, optional
  112. :param stream: Control if response is streamed back to user, defaults to False
  113. :type stream: bool, optional
  114. :param online: Controls whether to use internet for answering query, defaults to False
  115. :type online: bool, optional
  116. :param token_usage: Controls whether to return token usage in response, defaults to False
  117. :type token_usage: bool, optional
  118. :param deployment_name: t.b.a., defaults to None
  119. :type deployment_name: Optional[str], optional
  120. :param system_prompt: System prompt string, defaults to None
  121. :type system_prompt: Optional[str], optional
  122. :param where: A dictionary of key-value pairs to filter the database results., defaults to None
  123. :type where: dict[str, Any], optional
  124. :param api_key: The api key of the custom endpoint, defaults to None
  125. :type api_key: Optional[str], optional
  126. :param endpoint: The api url of the custom endpoint, defaults to None
  127. :type endpoint: Optional[str], optional
  128. :param model_kwargs: A dictionary of key-value pairs to pass to the model, defaults to None
  129. :type model_kwargs: Optional[Dict[str, Any]], optional
  130. :param callbacks: Langchain callback functions to use, defaults to None
  131. :type callbacks: Optional[list], optional
  132. :param query_type: The type of query to use, defaults to None
  133. :type query_type: Optional[str], optional
  134. :param http_client_proxies: The proxy server settings used to create self.http_client, defaults to None
  135. :type http_client_proxies: Optional[Dict | str], optional
  136. :param http_async_client_proxies: The proxy server settings for async calls used to create
  137. self.http_async_client, defaults to None
  138. :type http_async_client_proxies: Optional[Dict | str], optional
  139. :param local: If True, the model will be run locally, defaults to False (for huggingface provider)
  140. :type local: Optional[bool], optional
  141. :param default_headers: Set additional HTTP headers to be sent with requests to OpenAI
  142. :type default_headers: Optional[Mapping[str, str]], optional
  143. :raises ValueError: If the template is not valid as template should
  144. contain $context and $query (and optionally $history)
  145. :raises ValueError: Stream is not boolean
  146. """
        if template is not None:
            logger.warning(
                "The `template` argument is deprecated and will be removed in a future version. "
                + "Please use `prompt` instead."
            )
            if prompt is None:
                prompt = template

        if prompt is None:
            prompt = DEFAULT_PROMPT_TEMPLATE

        self.number_documents = number_documents
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.model = model
        self.top_p = top_p
        self.online = online
        self.token_usage = token_usage
        self.deployment_name = deployment_name
        self.system_prompt = system_prompt
        self.query_type = query_type
        self.callbacks = callbacks
        self.api_key = api_key
        self.base_url = base_url
        self.endpoint = endpoint
        self.model_kwargs = model_kwargs
        self.http_client = httpx.Client(proxies=http_client_proxies) if http_client_proxies else None
        self.http_async_client = (
            httpx.AsyncClient(proxies=http_async_client_proxies) if http_async_client_proxies else None
        )
        self.local = local
        self.default_headers = default_headers
        self.api_version = api_version

        if token_usage:
            # Pricing table used to compute token-usage cost; use a context
            # manager so the file handle is closed instead of leaked.
            with open("model_prices_and_context_window.json") as f:
                self.model_pricing_map = json.load(f)

        if isinstance(prompt, str):
            prompt = Template(prompt)

        if self.validate_prompt(prompt):
            self.prompt = prompt
        else:
            raise ValueError("The 'prompt' must contain '$context' and '$query' (and optionally '$history').")

        if not isinstance(stream, bool):
            raise ValueError("`stream` should be bool")
        self.stream = stream
        self.where = where

    @staticmethod
    def validate_prompt(prompt: Template) -> Optional[re.Match[str]]:
        """
        Validate that the prompt contains the required placeholders.

        :param prompt: the prompt to validate
        :type prompt: Template
        :return: a match object if the prompt contains both $context and $query, otherwise None
        :rtype: Optional[re.Match[str]]
        """
        return re.search(query_re, prompt.template) and re.search(context_re, prompt.template)

    @staticmethod
    def _validate_prompt_history(prompt: Template) -> Optional[re.Match[str]]:
        """
        Validate that the prompt contains the $history placeholder.

        :param prompt: the prompt to validate
        :type prompt: Template
        :return: a match object if the prompt contains $history, otherwise None
        :rtype: Optional[re.Match[str]]
        """
        return re.search(history_re, prompt.template)
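

# Minimal usage sketch (illustrative only; the argument values are assumptions,
# not embedchain defaults): a custom prompt passes validate_prompt because it
# contains both $context and $query.
if __name__ == "__main__":
    custom_prompt = Template("Context:\n$context\n\nQuery: $query\nAnswer:")
    config = BaseLlmConfig(number_documents=5, temperature=0.2, prompt=custom_prompt)
    print(bool(BaseLlmConfig.validate_prompt(config.prompt)))  # True
    print(bool(BaseLlmConfig._validate_prompt_history(config.prompt)))  # False: no $history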