base.py

import logging
import re
from string import Template
from typing import Any, Optional

from embedchain.config.base_config import BaseConfig
from embedchain.helpers.json_serializable import register_deserializable

DEFAULT_PROMPT = """
You are a Q&A expert system. Your responses must always be rooted in the context provided for each query. Here are some guidelines to follow:

1. Refrain from explicitly mentioning the context provided in your response.
2. The context should silently guide your answers without being directly acknowledged.
3. Do not use phrases such as 'According to the context provided', 'Based on the context, ...' etc.

Context information:
----------------------
$context
----------------------

Query: $query
Answer:
"""  # noqa: E501
DEFAULT_PROMPT_WITH_HISTORY = """
You are a Q&A expert system. Your responses must always be rooted in the context provided for each query. You are also provided with the conversation history with the user. Make sure to use relevant context from the conversation history as needed.

Here are some guidelines to follow:

1. Refrain from explicitly mentioning the context provided in your response.
2. The context should silently guide your answers without being directly acknowledged.
3. Do not use phrases such as 'According to the context provided', 'Based on the context, ...' etc.

Context information:
----------------------
$context
----------------------

Conversation history:
----------------------
$history
----------------------

Query: $query
Answer:
"""  # noqa: E501
DOCS_SITE_DEFAULT_PROMPT = """
You are an expert AI assistant for a developer support product. Your responses must always be rooted in the context provided for each query. Wherever possible, give a complete code snippet. Don't make up any code snippets on your own.

Here are some guidelines to follow:

1. Refrain from explicitly mentioning the context provided in your response.
2. The context should silently guide your answers without being directly acknowledged.
3. Do not use phrases such as 'According to the context provided', 'Based on the context, ...' etc.

Context information:
----------------------
$context
----------------------

Query: $query
Answer:
"""  # noqa: E501

DEFAULT_PROMPT_TEMPLATE = Template(DEFAULT_PROMPT)
DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE = Template(DEFAULT_PROMPT_WITH_HISTORY)
DOCS_SITE_PROMPT_TEMPLATE = Template(DOCS_SITE_DEFAULT_PROMPT)
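# For example, DEFAULT_PROMPT_TEMPLATE.substitute(context="...", query="...")
# returns the prompt text with the $context and $query placeholders filled in.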

# Match a placeholder in either `$name` or `${name}` form, covering both
# substitution syntaxes accepted by `string.Template`.
query_re = re.compile(r"\$\{*query\}*")
context_re = re.compile(r"\$\{*context\}*")
history_re = re.compile(r"\$\{*history\}*")

@register_deserializable
class BaseLlmConfig(BaseConfig):
    """
    Config for the `query` method.
    """

    def __init__(
        self,
        number_documents: int = 3,
        template: Optional[Template] = None,
        prompt: Optional[Template] = None,
        model: Optional[str] = None,
        temperature: float = 0,
        max_tokens: int = 1000,
        top_p: float = 1,
        stream: bool = False,
        deployment_name: Optional[str] = None,
        system_prompt: Optional[str] = None,
        where: Optional[dict[str, Any]] = None,
        query_type: Optional[str] = None,
        callbacks: Optional[list] = None,
        api_key: Optional[str] = None,
        endpoint: Optional[str] = None,
        model_kwargs: Optional[dict[str, Any]] = None,
    ):
  79. """
  80. Initializes a configuration class instance for the LLM.
  81. Takes the place of the former `QueryConfig` or `ChatConfig`.
  82. :param number_documents: Number of documents to pull from the database as
  83. context, defaults to 1
  84. :type number_documents: int, optional
  85. :param template: The `Template` instance to use as a template for
  86. prompt, defaults to None (deprecated)
  87. :type template: Optional[Template], optional
  88. :param prompt: The `Template` instance to use as a template for
  89. prompt, defaults to None
  90. :type prompt: Optional[Template], optional
  91. :param model: Controls the OpenAI model used, defaults to None
  92. :type model: Optional[str], optional
  93. :param temperature: Controls the randomness of the model's output.
  94. Higher values (closer to 1) make output more random, lower values make it more deterministic, defaults to 0
  95. :type temperature: float, optional
  96. :param max_tokens: Controls how many tokens are generated, defaults to 1000
  97. :type max_tokens: int, optional
  98. :param top_p: Controls the diversity of words. Higher values (closer to 1) make word selection more diverse,
  99. defaults to 1
  100. :type top_p: float, optional
  101. :param stream: Control if response is streamed back to user, defaults to False
  102. :type stream: bool, optional
  103. :param deployment_name: t.b.a., defaults to None
  104. :type deployment_name: Optional[str], optional
  105. :param system_prompt: System prompt string, defaults to None
  106. :type system_prompt: Optional[str], optional
  107. :param where: A dictionary of key-value pairs to filter the database results., defaults to None
  108. :type where: dict[str, Any], optional
  109. :param api_key: The api key of the custom endpoint, defaults to None
  110. :type api_key: Optional[str], optional
  111. :param endpoint: The api url of the custom endpoint, defaults to None
  112. :type endpoint: Optional[str], optional
  113. :param model_kwargs: A dictionary of key-value pairs to pass to the model, defaults to None
  114. :type model_kwargs: Optional[Dict[str, Any]], optional
  115. :param callbacks: Langchain callback functions to use, defaults to None
  116. :type callbacks: Optional[list], optional
  117. :param query_type: The type of query to use, defaults to None
  118. :type query_type: Optional[str], optional
  119. :raises ValueError: If the template is not valid as template should
  120. contain $context and $query (and optionally $history)
  121. :raises ValueError: Stream is not boolean
  122. """
        if template is not None:
            logging.warning(
                "The `template` argument is deprecated and will be removed in a future version. "
                "Please use `prompt` instead."
            )
            if prompt is None:
                prompt = template

        if prompt is None:
            prompt = DEFAULT_PROMPT_TEMPLATE

        self.number_documents = number_documents
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.model = model
        self.top_p = top_p
        self.deployment_name = deployment_name
        self.system_prompt = system_prompt
        self.query_type = query_type
        self.callbacks = callbacks
        self.api_key = api_key
        self.endpoint = endpoint
        self.model_kwargs = model_kwargs

        if isinstance(prompt, str):
            prompt = Template(prompt)

        if self.validate_prompt(prompt):
            self.prompt = prompt
        else:
            raise ValueError(
                "The `prompt` must contain the $context and $query placeholders (and optionally $history)."
            )

        if not isinstance(stream, bool):
            raise ValueError("`stream` should be bool")
        self.stream = stream
        self.where = where

    @staticmethod
    def validate_prompt(prompt: Template) -> Optional[re.Match[str]]:
        """
        Validate that the prompt contains both the $context and $query placeholders.

        :param prompt: the prompt to validate
        :type prompt: Template
        :return: a match object if the prompt is valid, otherwise None
        :rtype: Optional[re.Match[str]]
        """
        return query_re.search(prompt.template) and context_re.search(prompt.template)

    @staticmethod
    def _validate_prompt_history(prompt: Template) -> Optional[re.Match[str]]:
        """
        Validate that the prompt contains the $history placeholder.

        :param prompt: the prompt to validate
        :type prompt: Template
        :return: a match object if $history is present, otherwise None
        :rtype: Optional[re.Match[str]]
        """
        return history_re.search(prompt.template)
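

# A minimal usage sketch (illustrative; not part of the original module). It
# assumes embedchain is installed so the imports above resolve. The custom
# template below is hypothetical; any template containing $context and $query
# passes validation, otherwise __init__ raises ValueError.
if __name__ == "__main__":
    custom = Template("Use this context: $context\nQuery: $query\nAnswer:")
    config = BaseLlmConfig(prompt=custom, number_documents=5, stream=True)
    # Render the prompt with the placeholders filled in.
    print(config.prompt.substitute(context="(retrieved documents)", query="What does this config control?"))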