base.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. import logging
  2. import re
  3. from string import Template
  4. from typing import Any, Dict, List, Optional
  5. from embedchain.config.base_config import BaseConfig
  6. from embedchain.helpers.json_serializable import register_deserializable
# Prompt used when the caller supplies no custom template.
# `$context` and `$query` are substituted via `string.Template`.
DEFAULT_PROMPT = """
Use the following pieces of context to answer the query at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.

$context

Query: $query

Helpful Answer:
"""  # noqa:E501

# Variant of the default prompt that also interpolates `$history`
# (the prior conversation turns) for chat-style usage.
DEFAULT_PROMPT_WITH_HISTORY = """
Use the following pieces of context to answer the query at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
I will provide you with our conversation history.

$context

History: $history

Query: $query

Helpful Answer:
"""  # noqa:E501

# Prompt tuned for documentation-site Q&A: asks for complete code
# snippets and forbids fabricating code.
DOCS_SITE_DEFAULT_PROMPT = """
Use the following pieces of context to answer the query at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer. Wherever possible, give complete code snippet. Dont make up any code snippet on your own.

$context

Query: $query

Helpful Answer:
"""  # noqa:E501

# Pre-built `Template` objects for the prompts above.
DEFAULT_PROMPT_TEMPLATE = Template(DEFAULT_PROMPT)
DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE = Template(DEFAULT_PROMPT_WITH_HISTORY)
DOCS_SITE_PROMPT_TEMPLATE = Template(DOCS_SITE_DEFAULT_PROMPT)

# Detect the required placeholders in a user-supplied template.
# Each pattern matches both `$name` and `${name}` forms (the `{`/`}`
# quantifiers are permissive: zero or more braces on either side).
query_re = re.compile(r"\$\{*query\}*")
context_re = re.compile(r"\$\{*context\}*")
history_re = re.compile(r"\$\{*history\}*")
  36. @register_deserializable
  37. class BaseLlmConfig(BaseConfig):
  38. """
  39. Config for the `query` method.
  40. """
  41. def __init__(
  42. self,
  43. number_documents: int = 3,
  44. template: Optional[Template] = None,
  45. prompt: Optional[Template] = None,
  46. model: Optional[str] = None,
  47. temperature: float = 0,
  48. max_tokens: int = 1000,
  49. top_p: float = 1,
  50. stream: bool = False,
  51. deployment_name: Optional[str] = None,
  52. system_prompt: Optional[str] = None,
  53. where: Dict[str, Any] = None,
  54. query_type: Optional[str] = None,
  55. callbacks: Optional[List] = None,
  56. api_key: Optional[str] = None,
  57. ):
  58. """
  59. Initializes a configuration class instance for the LLM.
  60. Takes the place of the former `QueryConfig` or `ChatConfig`.
  61. :param number_documents: Number of documents to pull from the database as
  62. context, defaults to 1
  63. :type number_documents: int, optional
  64. :param template: The `Template` instance to use as a template for
  65. prompt, defaults to None (deprecated)
  66. :type template: Optional[Template], optional
  67. :param prompt: The `Template` instance to use as a template for
  68. prompt, defaults to None
  69. :type prompt: Optional[Template], optional
  70. :param model: Controls the OpenAI model used, defaults to None
  71. :type model: Optional[str], optional
  72. :param temperature: Controls the randomness of the model's output.
  73. Higher values (closer to 1) make output more random, lower values make it more deterministic, defaults to 0
  74. :type temperature: float, optional
  75. :param max_tokens: Controls how many tokens are generated, defaults to 1000
  76. :type max_tokens: int, optional
  77. :param top_p: Controls the diversity of words. Higher values (closer to 1) make word selection more diverse,
  78. defaults to 1
  79. :type top_p: float, optional
  80. :param stream: Control if response is streamed back to user, defaults to False
  81. :type stream: bool, optional
  82. :param deployment_name: t.b.a., defaults to None
  83. :type deployment_name: Optional[str], optional
  84. :param system_prompt: System prompt string, defaults to None
  85. :type system_prompt: Optional[str], optional
  86. :param where: A dictionary of key-value pairs to filter the database results., defaults to None
  87. :type where: Dict[str, Any], optional
  88. :param callbacks: Langchain callback functions to use, defaults to None
  89. :type callbacks: Optional[List], optional
  90. :raises ValueError: If the template is not valid as template should
  91. contain $context and $query (and optionally $history)
  92. :raises ValueError: Stream is not boolean
  93. """
  94. if template is not None:
  95. logging.warning(
  96. "The `template` argument is deprecated and will be removed in a future version. "
  97. + "Please use `prompt` instead."
  98. )
  99. if prompt is None:
  100. prompt = template
  101. if prompt is None:
  102. prompt = DEFAULT_PROMPT_TEMPLATE
  103. self.number_documents = number_documents
  104. self.temperature = temperature
  105. self.max_tokens = max_tokens
  106. self.model = model
  107. self.top_p = top_p
  108. self.deployment_name = deployment_name
  109. self.system_prompt = system_prompt
  110. self.query_type = query_type
  111. self.callbacks = callbacks
  112. self.api_key = api_key
  113. if type(prompt) is str:
  114. prompt = Template(prompt)
  115. if self.validate_prompt(prompt):
  116. self.prompt = prompt
  117. else:
  118. raise ValueError("The 'prompt' should have 'query' and 'context' keys and potentially 'history' (if used).")
  119. if not isinstance(stream, bool):
  120. raise ValueError("`stream` should be bool")
  121. self.stream = stream
  122. self.where = where
  123. def validate_prompt(self, prompt: Template) -> bool:
  124. """
  125. validate the prompt
  126. :param prompt: the prompt to validate
  127. :type prompt: Template
  128. :return: valid (true) or invalid (false)
  129. :rtype: bool
  130. """
  131. return re.search(query_re, prompt.template) and re.search(context_re, prompt.template)
  132. def _validate_prompt_history(self, prompt: Template) -> bool:
  133. """
  134. validate the prompt with history
  135. :param prompt: the prompt to validate
  136. :type prompt: Template
  137. :return: valid (true) or invalid (false)
  138. :rtype: bool
  139. """
  140. return re.search(history_re, prompt.template)