base.py

import logging
import re
from string import Template
from typing import Any, Optional

from embedchain.config.base_config import BaseConfig
from embedchain.helpers.json_serializable import register_deserializable

DEFAULT_PROMPT = """
Use the following pieces of context to answer the query at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.

$context

Query: $query

Helpful Answer:
"""  # noqa:E501

DEFAULT_PROMPT_WITH_HISTORY = """
Use the following pieces of context to answer the query at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
I will provide you with our conversation history.

$context

History:
$history

Query: $query

Helpful Answer:
"""  # noqa:E501

DOCS_SITE_DEFAULT_PROMPT = """
Use the following pieces of context to answer the query at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer. Wherever possible, give a complete code snippet. Don't make up any code snippet on your own.

$context

Query: $query

Helpful Answer:
"""  # noqa:E501

DEFAULT_PROMPT_TEMPLATE = Template(DEFAULT_PROMPT)
DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE = Template(DEFAULT_PROMPT_WITH_HISTORY)
DOCS_SITE_PROMPT_TEMPLATE = Template(DOCS_SITE_DEFAULT_PROMPT)
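
# Editor-added note: the three constants above are plain string.Template
# objects, so rendering a final prompt is a single substitute() call, e.g.:
#
#   DEFAULT_PROMPT_TEMPLATE.substitute(
#       context="Paris is the capital of France.",
#       query="What is the capital of France?",
#   )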

query_re = re.compile(r"\$\{*query\}*")
context_re = re.compile(r"\$\{*context\}*")
history_re = re.compile(r"\$\{*history\}*")
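
# Editor-added note: each pattern above accepts both the bare and the braced
# spelling of a string.Template placeholder, e.g.:
#
#   query_re.search("Query: $query")    # matches "$query"
#   query_re.search("Query: ${query}")  # matches "${query}"
#   query_re.search("no placeholder")   # returns None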


@register_deserializable
class BaseLlmConfig(BaseConfig):
    """
    Config for the `query` method.
    """

    def __init__(
        self,
        number_documents: int = 3,
        template: Optional[Template] = None,
        prompt: Optional[Template] = None,
        model: Optional[str] = None,
        temperature: float = 0,
        max_tokens: int = 1000,
        top_p: float = 1,
        stream: bool = False,
        deployment_name: Optional[str] = None,
        system_prompt: Optional[str] = None,
        where: Optional[dict[str, Any]] = None,
        query_type: Optional[str] = None,
        callbacks: Optional[list] = None,
        api_key: Optional[str] = None,
        endpoint: Optional[str] = None,
        model_kwargs: Optional[dict[str, Any]] = None,
    ):
  61. """
  62. Initializes a configuration class instance for the LLM.
  63. Takes the place of the former `QueryConfig` or `ChatConfig`.
  64. :param number_documents: Number of documents to pull from the database as
  65. context, defaults to 1
  66. :type number_documents: int, optional
  67. :param template: The `Template` instance to use as a template for
  68. prompt, defaults to None (deprecated)
  69. :type template: Optional[Template], optional
  70. :param prompt: The `Template` instance to use as a template for
  71. prompt, defaults to None
  72. :type prompt: Optional[Template], optional
  73. :param model: Controls the OpenAI model used, defaults to None
  74. :type model: Optional[str], optional
  75. :param temperature: Controls the randomness of the model's output.
  76. Higher values (closer to 1) make output more random, lower values make it more deterministic, defaults to 0
  77. :type temperature: float, optional
  78. :param max_tokens: Controls how many tokens are generated, defaults to 1000
  79. :type max_tokens: int, optional
  80. :param top_p: Controls the diversity of words. Higher values (closer to 1) make word selection more diverse,
  81. defaults to 1
  82. :type top_p: float, optional
  83. :param stream: Control if response is streamed back to user, defaults to False
  84. :type stream: bool, optional
  85. :param deployment_name: t.b.a., defaults to None
  86. :type deployment_name: Optional[str], optional
  87. :param system_prompt: System prompt string, defaults to None
  88. :type system_prompt: Optional[str], optional
  89. :param where: A dictionary of key-value pairs to filter the database results., defaults to None
  90. :type where: dict[str, Any], optional
  91. :param api_key: The api key of the custom endpoint, defaults to None
  92. :type api_key: Optional[str], optional
  93. :param endpoint: The api url of the custom endpoint, defaults to None
  94. :type endpoint: Optional[str], optional
  95. :param model_kwargs: A dictionary of key-value pairs to pass to the model, defaults to None
  96. :type model_kwargs: Optional[Dict[str, Any]], optional
  97. :param callbacks: Langchain callback functions to use, defaults to None
  98. :type callbacks: Optional[list], optional
  99. :param query_type: The type of query to use, defaults to None
  100. :type query_type: Optional[str], optional
  101. :raises ValueError: If the template is not valid as template should
  102. contain $context and $query (and optionally $history)
  103. :raises ValueError: Stream is not boolean
  104. """
        if template is not None:
            logging.warning(
                "The `template` argument is deprecated and will be removed in a future version. "
                + "Please use `prompt` instead."
            )
            if prompt is None:
                prompt = template

        if prompt is None:
            prompt = DEFAULT_PROMPT_TEMPLATE

        self.number_documents = number_documents
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.model = model
        self.top_p = top_p
        self.deployment_name = deployment_name
        self.system_prompt = system_prompt
        self.query_type = query_type
        self.callbacks = callbacks
        self.api_key = api_key
        self.endpoint = endpoint
        self.model_kwargs = model_kwargs

        if isinstance(prompt, str):
            prompt = Template(prompt)

        if self.validate_prompt(prompt):
            self.prompt = prompt
        else:
            raise ValueError("The 'prompt' should have 'query' and 'context' keys and potentially 'history' (if used).")

        if not isinstance(stream, bool):
            raise ValueError("`stream` should be bool")
        self.stream = stream
        self.where = where

    @staticmethod
    def validate_prompt(prompt: Template) -> Optional[re.Match[str]]:
        """
        Validate the prompt: it must contain both the $context and $query
        placeholders.

        :param prompt: the prompt to validate
        :type prompt: Template
        :return: the match object if the prompt is valid, otherwise None
        :rtype: Optional[re.Match[str]]
        """
        return re.search(query_re, prompt.template) and re.search(context_re, prompt.template)

    @staticmethod
    def _validate_prompt_history(prompt: Template) -> Optional[re.Match[str]]:
        """
        Validate that the prompt contains the $history placeholder.

        :param prompt: the prompt to validate
        :type prompt: Template
        :return: the match object if the placeholder is present, otherwise None
        :rtype: Optional[re.Match[str]]
        """
        return re.search(history_re, prompt.template)
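

# Editor-added usage sketch: demonstrates the validation contract described in
# __init__'s docstring. Guarded so it only runs when this file is executed
# directly; it uses only names defined in this module.
if __name__ == "__main__":
    # A custom prompt is accepted as long as it contains $context and $query.
    config = BaseLlmConfig(
        number_documents=5,
        prompt=Template("Context: $context\nQuery: $query\nHelpful Answer:"),
        temperature=0.2,
        stream=True,
    )
    print(config.prompt.substitute(context="Paris is in France.", query="Where is Paris?"))

    # A prompt missing $context fails validate_prompt and raises ValueError.
    try:
        BaseLlmConfig(prompt=Template("just $query"))
    except ValueError as exc:
        print(f"rejected invalid prompt: {exc}")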