# QueryConfig.py (4.1 KB)
import re
from string import Template

from embedchain.config.BaseConfig import BaseConfig
  4. DEFAULT_PROMPT = """
  5. Use the following pieces of context to answer the query at the end.
  6. If you don't know the answer, just say that you don't know, don't try to make up an answer.
  7. $context
  8. Query: $query
  9. Helpful Answer:
  10. """ # noqa:E501
  11. DEFAULT_PROMPT_WITH_HISTORY = """
  12. Use the following pieces of context to answer the query at the end.
  13. If you don't know the answer, just say that you don't know, don't try to make up an answer.
  14. I will provide you with our conversation history.
  15. $context
  16. History: $history
  17. Query: $query
  18. Helpful Answer:
  19. """ # noqa:E501
  20. DEFAULT_PROMPT_TEMPLATE = Template(DEFAULT_PROMPT)
  21. DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE = Template(DEFAULT_PROMPT_WITH_HISTORY)
  22. query_re = re.compile(r"\$\{*query\}*")
  23. context_re = re.compile(r"\$\{*context\}*")
  24. history_re = re.compile(r"\$\{*history\}*")
  25. class QueryConfig(BaseConfig):
  26. """
  27. Config for the `query` method.
  28. """
  29. def __init__(
  30. self,
  31. number_documents=None,
  32. template: Template = None,
  33. model=None,
  34. temperature=None,
  35. max_tokens=None,
  36. top_p=None,
  37. history=None,
  38. stream: bool = False,
  39. ):
  40. """
  41. Initializes the QueryConfig instance.
  42. :param number_documents: Number of documents to pull from the database as
  43. context.
  44. :param template: Optional. The `Template` instance to use as a template for
  45. prompt.
  46. :param model: Optional. Controls the OpenAI model used.
  47. :param temperature: Optional. Controls the randomness of the model's output.
  48. Higher values (closer to 1) make output more random, lower values make it more
  49. deterministic.
  50. :param max_tokens: Optional. Controls how many tokens are generated.
  51. :param top_p: Optional. Controls the diversity of words. Higher values
  52. (closer to 1) make word selection more diverse, lower values make words less
  53. diverse.
  54. :param history: Optional. A list of strings to consider as history.
  55. :param stream: Optional. Control if response is streamed back to user
  56. :raises ValueError: If the template is not valid as template should
  57. contain $context and $query (and optionally $history).
  58. """
  59. if number_documents is None:
  60. self.number_documents = 1
  61. else:
  62. self.number_documents = number_documents
  63. if not history:
  64. self.history = None
  65. else:
  66. if len(history) == 0:
  67. self.history = None
  68. else:
  69. self.history = history
  70. if template is None:
  71. if self.history is None:
  72. template = DEFAULT_PROMPT_TEMPLATE
  73. else:
  74. template = DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE
  75. self.temperature = temperature if temperature else 0
  76. self.max_tokens = max_tokens if max_tokens else 1000
  77. self.model = model if model else "gpt-3.5-turbo-0613"
  78. self.top_p = top_p if top_p else 1
  79. if self.validate_template(template):
  80. self.template = template
  81. else:
  82. if self.history is None:
  83. raise ValueError("`template` should have `query` and `context` keys")
  84. else:
  85. raise ValueError(
  86. "`template` should have `query`, `context` and `history` keys"
  87. )
  88. if not isinstance(stream, bool):
  89. raise ValueError("`stream` should be bool")
  90. self.stream = stream
  91. def validate_template(self, template: Template):
  92. """
  93. validate the template
  94. :param template: the template to validate
  95. :return: Boolean, valid (true) or invalid (false)
  96. """
  97. if self.history is None:
  98. return re.search(query_re, template.template) and re.search(
  99. context_re, template.template
  100. )
  101. else:
  102. return (
  103. re.search(query_re, template.template)
  104. and re.search(context_re, template.template)
  105. and re.search(history_re, template.template)
  106. )