QueryConfig.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. from embedchain.config.BaseConfig import BaseConfig
  2. from string import Template
  3. import re
  4. DEFAULT_PROMPT = """
  5. Use the following pieces of context to answer the query at the end.
  6. If you don't know the answer, just say that you don't know, don't try to make up an answer.
  7. $context
  8. Query: $query
  9. Helpful Answer:
  10. """
  11. DEFAULT_PROMPT_WITH_HISTORY = """
  12. Use the following pieces of context to answer the query at the end.
  13. If you don't know the answer, just say that you don't know, don't try to make up an answer.
  14. I will provide you with our conversation history.
  15. $context
  16. History: $history
  17. Query: $query
  18. Helpful Answer:
  19. """
  20. DEFAULT_PROMPT_TEMPLATE = Template(DEFAULT_PROMPT)
  21. DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE = Template(DEFAULT_PROMPT_WITH_HISTORY)
  22. query_re = re.compile(r"\$\{*query\}*")
  23. context_re = re.compile(r"\$\{*context\}*")
  24. history_re = re.compile(r"\$\{*history\}*")
  25. class QueryConfig(BaseConfig):
  26. """
  27. Config for the `query` method.
  28. """
  29. def __init__(self, template: Template = None, history = None, stream: bool = False):
  30. """
  31. Initializes the QueryConfig instance.
  32. :param template: Optional. The `Template` instance to use as a template for prompt.
  33. :param history: Optional. A list of strings to consider as history.
  34. :param stream: Optional. Control if response is streamed back to the user
  35. :raises ValueError: If the template is not valid as template should contain $context and $query (and optionally $history).
  36. """
  37. if not history:
  38. self.history = None
  39. else:
  40. if len(history) == 0:
  41. self.history = None
  42. else:
  43. self.history = history
  44. if template is None:
  45. if self.history is None:
  46. template = DEFAULT_PROMPT_TEMPLATE
  47. else:
  48. template = DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE
  49. if self.validate_template(template):
  50. self.template = template
  51. else:
  52. if self.history is None:
  53. raise ValueError("`template` should have `query` and `context` keys")
  54. else:
  55. raise ValueError("`template` should have `query`, `context` and `history` keys")
  56. if not isinstance(stream, bool):
  57. raise ValueError("`stream` should be bool")
  58. self.stream = stream
  59. def validate_template(self, template: Template):
  60. """
  61. validate the template
  62. :param template: the template to validate
  63. :return: Boolean, valid (true) or invalid (false)
  64. """
  65. if self.history is None:
  66. return (re.search(query_re, template.template) \
  67. and re.search(context_re, template.template))
  68. else:
  69. return (re.search(query_re, template.template) \
  70. and re.search(context_re, template.template)
  71. and re.search(history_re, template.template))