PersonApp.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. from string import Template
  2. from embedchain.apps.App import App
  3. from embedchain.apps.OpenSourceApp import OpenSourceApp
  4. from embedchain.config import BaseLlmConfig
  5. from embedchain.config.apps.BaseAppConfig import BaseAppConfig
  6. from embedchain.config.llm.base_llm_config import (DEFAULT_PROMPT,
  7. DEFAULT_PROMPT_WITH_HISTORY)
  8. from embedchain.helper_classes.json_serializable import register_deserializable
  9. @register_deserializable
  10. class EmbedChainPersonApp:
  11. """
  12. Base class to create a person bot.
  13. This bot behaves and speaks like a person.
  14. :param person: name of the person, better if its a well known person.
  15. :param config: BaseAppConfig instance to load as configuration.
  16. """
  17. def __init__(self, person, config: BaseAppConfig = None):
  18. self.person = person
  19. self.person_prompt = f"You are {person}. Whatever you say, you will always say in {person} style." # noqa:E501
  20. super().__init__(config)
  21. def add_person_template_to_config(self, default_prompt: str, config: BaseLlmConfig = None):
  22. """
  23. This method checks if the config object contains a prompt template
  24. if yes it adds the person prompt to it and return the updated config
  25. else it creates a config object with the default prompt added to the person prompt
  26. :param default_prompt: it is the default prompt for query or chat methods
  27. :param config: Optional. The `ChatConfig` instance to use as
  28. configuration options.
  29. """
  30. template = Template(self.person_prompt + " " + default_prompt)
  31. if config:
  32. if config.template:
  33. # Add person prompt to custom user template
  34. config.template = Template(self.person_prompt + " " + config.template.template)
  35. else:
  36. # If no user template is present, use person prompt with the default template
  37. config.template = template
  38. else:
  39. # if no config is present at all, initialize the config with person prompt and default template
  40. config = BaseLlmConfig(
  41. template=template,
  42. )
  43. return config
  44. @register_deserializable
  45. class PersonApp(EmbedChainPersonApp, App):
  46. """
  47. The Person app.
  48. Extends functionality from EmbedChainPersonApp and App
  49. """
  50. def query(self, input_query, config: BaseLlmConfig = None, dry_run=False):
  51. config = self.add_person_template_to_config(DEFAULT_PROMPT, config, where=None)
  52. return super().query(input_query, config, dry_run, where=None)
  53. def chat(self, input_query, config: BaseLlmConfig = None, dry_run=False, where=None):
  54. config = self.add_person_template_to_config(DEFAULT_PROMPT_WITH_HISTORY, config)
  55. return super().chat(input_query, config, dry_run, where)
  56. @register_deserializable
  57. class PersonOpenSourceApp(EmbedChainPersonApp, OpenSourceApp):
  58. """
  59. The Person app.
  60. Extends functionality from EmbedChainPersonApp and OpenSourceApp
  61. """
  62. def query(self, input_query, config: BaseLlmConfig = None, dry_run=False):
  63. config = self.add_person_template_to_config(DEFAULT_PROMPT, config)
  64. return super().query(input_query, config, dry_run)
  65. def chat(self, input_query, config: BaseLlmConfig = None, dry_run=False):
  66. config = self.add_person_template_to_config(DEFAULT_PROMPT_WITH_HISTORY, config)
  67. return super().chat(input_query, config, dry_run)