# base_llm_config.py
  1. import re
  2. from string import Template
  3. from typing import Optional
  4. from embedchain.config.BaseConfig import BaseConfig
  5. from embedchain.helper_classes.json_serializable import register_deserializable
  6. DEFAULT_PROMPT = """
  7. Use the following pieces of context to answer the query at the end.
  8. If you don't know the answer, just say that you don't know, don't try to make up an answer.
  9. $context
  10. Query: $query
  11. Helpful Answer:
  12. """ # noqa:E501
  13. DEFAULT_PROMPT_WITH_HISTORY = """
  14. Use the following pieces of context to answer the query at the end.
  15. If you don't know the answer, just say that you don't know, don't try to make up an answer.
  16. I will provide you with our conversation history.
  17. $context
  18. History: $history
  19. Query: $query
  20. Helpful Answer:
  21. """ # noqa:E501
  22. DOCS_SITE_DEFAULT_PROMPT = """
  23. Use the following pieces of context to answer the query at the end.
  24. If you don't know the answer, just say that you don't know, don't try to make up an answer. Wherever possible, give complete code snippet. Dont make up any code snippet on your own.
  25. $context
  26. Query: $query
  27. Helpful Answer:
  28. """ # noqa:E501
  29. DEFAULT_PROMPT_TEMPLATE = Template(DEFAULT_PROMPT)
  30. DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE = Template(DEFAULT_PROMPT_WITH_HISTORY)
  31. DOCS_SITE_PROMPT_TEMPLATE = Template(DOCS_SITE_DEFAULT_PROMPT)
  32. query_re = re.compile(r"\$\{*query\}*")
  33. context_re = re.compile(r"\$\{*context\}*")
  34. history_re = re.compile(r"\$\{*history\}*")
  35. @register_deserializable
  36. class BaseLlmConfig(BaseConfig):
  37. """
  38. Config for the `query` method.
  39. """
  40. def __init__(
  41. self,
  42. number_documents=None,
  43. template: Template = None,
  44. model=None,
  45. temperature=None,
  46. max_tokens=None,
  47. top_p=None,
  48. stream: bool = False,
  49. deployment_name=None,
  50. system_prompt: Optional[str] = None,
  51. where=None,
  52. ):
  53. """
  54. Initializes the QueryConfig instance.
  55. :param number_documents: Number of documents to pull from the database as
  56. context.
  57. :param template: Optional. The `Template` instance to use as a template for
  58. prompt.
  59. :param model: Optional. Controls the OpenAI model used.
  60. :param temperature: Optional. Controls the randomness of the model's output.
  61. Higher values (closer to 1) make output more random, lower values make it more
  62. deterministic.
  63. :param max_tokens: Optional. Controls how many tokens are generated.
  64. :param top_p: Optional. Controls the diversity of words. Higher values
  65. (closer to 1) make word selection more diverse, lower values make words less
  66. diverse.
  67. :param stream: Optional. Control if response is streamed back to user
  68. :param deployment_name: t.b.a.
  69. :param system_prompt: Optional. System prompt string.
  70. :param where: Optional. A dictionary of key-value pairs to filter the database results.
  71. :raises ValueError: If the template is not valid as template should
  72. contain $context and $query (and optionally $history).
  73. """
  74. if number_documents is None:
  75. self.number_documents = 1
  76. else:
  77. self.number_documents = number_documents
  78. if template is None:
  79. template = DEFAULT_PROMPT_TEMPLATE
  80. self.temperature = temperature if temperature else 0
  81. self.max_tokens = max_tokens if max_tokens else 1000
  82. self.model = model
  83. self.top_p = top_p if top_p else 1
  84. self.deployment_name = deployment_name
  85. self.system_prompt = system_prompt
  86. if self.validate_template(template):
  87. self.template = template
  88. else:
  89. raise ValueError("`template` should have `query` and `context` keys and potentially `history` (if used).")
  90. if not isinstance(stream, bool):
  91. raise ValueError("`stream` should be bool")
  92. self.stream = stream
  93. self.where = where
  94. def validate_template(self, template: Template):
  95. """
  96. validate the template
  97. :param template: the template to validate
  98. :return: Boolean, valid (true) or invalid (false)
  99. """
  100. return re.search(query_re, template.template) and re.search(context_re, template.template)
  101. def _validate_template_history(self, template: Template):
  102. """
  103. validate the history template for history
  104. :param template: the template to validate
  105. :return: Boolean, valid (true) or invalid (false)
  106. """
  107. return re.search(history_re, template.template)