base.py

import re
from string import Template
from typing import Any, Dict, List, Optional

from embedchain.config.base_config import BaseConfig
from embedchain.helpers.json_serializable import register_deserializable

DEFAULT_PROMPT = """
Use the following pieces of context to answer the query at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.

$context

Query: $query

Helpful Answer:
"""  # noqa:E501

DEFAULT_PROMPT_WITH_HISTORY = """
Use the following pieces of context to answer the query at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
I will provide you with our conversation history.

$context

History: $history

Query: $query

Helpful Answer:
"""  # noqa:E501

DOCS_SITE_DEFAULT_PROMPT = """
Use the following pieces of context to answer the query at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer. Wherever possible, give complete code snippets. Don't make up any code snippets on your own.

$context

Query: $query

Helpful Answer:
"""  # noqa:E501

DEFAULT_PROMPT_TEMPLATE = Template(DEFAULT_PROMPT)
DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE = Template(DEFAULT_PROMPT_WITH_HISTORY)
DOCS_SITE_PROMPT_TEMPLATE = Template(DOCS_SITE_DEFAULT_PROMPT)
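
# The patterns below match both placeholder spellings that string.Template
# accepts: the bare form (`$query`) and the braced form (`${query}`).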
query_re = re.compile(r"\$\{*query\}*")
context_re = re.compile(r"\$\{*context\}*")
history_re = re.compile(r"\$\{*history\}*")


@register_deserializable
class BaseLlmConfig(BaseConfig):
    """
    Config for the `query` method.
    """

    def __init__(
        self,
        number_documents: int = 3,
        template: Optional[Template] = None,
        model: Optional[str] = None,
        temperature: float = 0,
        max_tokens: int = 1000,
        top_p: float = 1,
        stream: bool = False,
        deployment_name: Optional[str] = None,
        system_prompt: Optional[str] = None,
        where: Optional[Dict[str, Any]] = None,
        query_type: Optional[str] = None,
        callbacks: Optional[List] = None,
        api_key: Optional[str] = None,
    ):
  56. """
  57. Initializes a configuration class instance for the LLM.
  58. Takes the place of the former `QueryConfig` or `ChatConfig`.
  59. :param number_documents: Number of documents to pull from the database as
  60. context, defaults to 1
  61. :type number_documents: int, optional
  62. :param template: The `Template` instance to use as a template for
  63. prompt, defaults to None
  64. :type template: Optional[Template], optional
  65. :param model: Controls the OpenAI model used, defaults to None
  66. :type model: Optional[str], optional
  67. :param temperature: Controls the randomness of the model's output.
  68. Higher values (closer to 1) make output more random, lower values make it more deterministic, defaults to 0
  69. :type temperature: float, optional
  70. :param max_tokens: Controls how many tokens are generated, defaults to 1000
  71. :type max_tokens: int, optional
  72. :param top_p: Controls the diversity of words. Higher values (closer to 1) make word selection more diverse,
  73. defaults to 1
  74. :type top_p: float, optional
  75. :param stream: Control if response is streamed back to user, defaults to False
  76. :type stream: bool, optional
  77. :param deployment_name: t.b.a., defaults to None
  78. :type deployment_name: Optional[str], optional
  79. :param system_prompt: System prompt string, defaults to None
  80. :type system_prompt: Optional[str], optional
  81. :param where: A dictionary of key-value pairs to filter the database results., defaults to None
  82. :type where: Dict[str, Any], optional
  83. :param callbacks: Langchain callback functions to use, defaults to None
  84. :type callbacks: Optional[List], optional
  85. :raises ValueError: If the template is not valid as template should
  86. contain $context and $query (and optionally $history)
  87. :raises ValueError: Stream is not boolean
  88. """
        if template is None:
            template = DEFAULT_PROMPT_TEMPLATE

        self.number_documents = number_documents
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.model = model
        self.top_p = top_p
        self.deployment_name = deployment_name
        self.system_prompt = system_prompt
        self.query_type = query_type
        self.callbacks = callbacks
        self.api_key = api_key

        if isinstance(template, str):
            template = Template(template)

        if self.validate_template(template):
            self.template = template
        else:
            raise ValueError(
                "`template` should contain `$query` and `$context` placeholders, and optionally `$history` (if used)."
            )

        if not isinstance(stream, bool):
            raise ValueError("`stream` should be bool")
        self.stream = stream
        self.where = where

    def validate_template(self, template: Template) -> bool:
        """
        Validate that the template contains the required placeholders.

        :param template: The template to validate.
        :type template: Template
        :return: valid (true) or invalid (false)
        :rtype: bool
        """
        return bool(re.search(query_re, template.template) and re.search(context_re, template.template))

    def _validate_template_history(self, template: Template) -> bool:
        """
        Validate that the template contains a history placeholder.

        :param template: The template to validate.
        :type template: Template
        :return: valid (true) or invalid (false)
        :rtype: bool
        """
        return bool(re.search(history_re, template.template))
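

# --- Usage sketch (not part of the original module) ---
# A minimal, hedged example of how this config can be constructed. The custom
# template below is illustrative only; any string containing the `$context`
# and `$query` placeholders would pass `validate_template`.
if __name__ == "__main__":
    custom_template = Template(
        """
        Answer the query using only the context below.

        $context

        Query: $query
        Helpful Answer:
        """
    )
    config = BaseLlmConfig(
        number_documents=5,
        template=custom_template,
        temperature=0.2,
        stream=False,
    )
    # `safe_substitute` renders the stored template and, unlike `substitute`,
    # leaves any unused placeholders (e.g. `$history`) intact instead of raising.
    prompt = config.template.safe_substitute(context="<retrieved docs>", query="What is embedchain?")
    print(prompt)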