# PersonApp.py (3.3 KB)
import copy
from string import Template

from embedchain.apps.App import App
from embedchain.apps.OpenSourceApp import OpenSourceApp
from embedchain.config import ChatConfig, QueryConfig
from embedchain.config.apps.BaseAppConfig import BaseAppConfig
from embedchain.config.QueryConfig import DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY
from embedchain.helper_classes.json_serializable import register_deserializable
  8. @register_deserializable
  9. class EmbedChainPersonApp:
  10. """
  11. Base class to create a person bot.
  12. This bot behaves and speaks like a person.
  13. :param person: name of the person, better if its a well known person.
  14. :param config: BaseAppConfig instance to load as configuration.
  15. """
  16. def __init__(self, person, config: BaseAppConfig = None):
  17. self.person = person
  18. self.person_prompt = f"You are {person}. Whatever you say, you will always say in {person} style." # noqa:E501
  19. super().__init__(config)
  20. def add_person_template_to_config(self, default_prompt: str, config: ChatConfig = None):
  21. """
  22. This method checks if the config object contains a prompt template
  23. if yes it adds the person prompt to it and return the updated config
  24. else it creates a config object with the default prompt added to the person prompt
  25. :param default_prompt: it is the default prompt for query or chat methods
  26. :param config: Optional. The `ChatConfig` instance to use as
  27. configuration options.
  28. """
  29. template = Template(self.person_prompt + " " + default_prompt)
  30. if config:
  31. if config.template:
  32. # Add person prompt to custom user template
  33. config.template = Template(self.person_prompt + " " + config.template.template)
  34. else:
  35. # If no user template is present, use person prompt with the default template
  36. config.template = template
  37. else:
  38. # if no config is present at all, initialize the config with person prompt and default template
  39. config = QueryConfig(
  40. template=template,
  41. )
  42. return config
  43. @register_deserializable
  44. class PersonApp(EmbedChainPersonApp, App):
  45. """
  46. The Person app.
  47. Extends functionality from EmbedChainPersonApp and App
  48. """
  49. def query(self, input_query, config: QueryConfig = None, dry_run=False):
  50. config = self.add_person_template_to_config(DEFAULT_PROMPT, config, where=None)
  51. return super().query(input_query, config, dry_run, where=None)
  52. def chat(self, input_query, config: ChatConfig = None, dry_run=False, where=None):
  53. config = self.add_person_template_to_config(DEFAULT_PROMPT_WITH_HISTORY, config)
  54. return super().chat(input_query, config, dry_run, where)
  55. @register_deserializable
  56. class PersonOpenSourceApp(EmbedChainPersonApp, OpenSourceApp):
  57. """
  58. The Person app.
  59. Extends functionality from EmbedChainPersonApp and OpenSourceApp
  60. """
  61. def query(self, input_query, config: QueryConfig = None, dry_run=False):
  62. config = self.add_person_template_to_config(DEFAULT_PROMPT, config)
  63. return super().query(input_query, config, dry_run)
  64. def chat(self, input_query, config: ChatConfig = None, dry_run=False):
  65. config = self.add_person_template_to_config(DEFAULT_PROMPT_WITH_HISTORY, config)
  66. return super().chat(input_query, config, dry_run)