Jelajahi Sumber

fix: Personapp not working with config (#368)

aaishikdutta 2 tahun lalu
induk
melakukan
cce6d5ddab
1 mengubah file dengan 35 tambahan dan 20 penghapusan
  1. 35 20
      embedchain/apps/PersonApp.py

+ 35 - 20
embedchain/apps/PersonApp.py

@@ -22,6 +22,33 @@ class EmbedChainPersonApp:
         self.person_prompt = f"You are {person}. Whatever you say, you will always say in {person} style."  # noqa:E501
         super().__init__(config)
 
+    def add_person_template_to_config(self, default_prompt: str, config: ChatConfig = None):
+        """
+        This method checks if the config object contains a prompt template
+        if yes it adds the person prompt to it and return the updated config
+        else it creates a config object with the default prompt added to the person prompt
+
+        :param default_prompt: it is the default prompt for query or chat methods
+        :param config: Optional. The `ChatConfig` instance to use as
+        configuration options.
+        """
+        template = Template(self.person_prompt + " " + default_prompt)
+
+        if config:
+            if config.template:
+                # Add person prompt to custom user template
+                config.template = Template(self.person_prompt + " " + config.template.template)
+            else:
+                # If no user template is present, use person prompt with the default template
+                config.template = template
+        else:
+            # if no config is present at all, initialize the config with person prompt and default template
+            config = QueryConfig(
+                template=template,
+            )
+
+        return config
+
 
 class PersonApp(EmbedChainPersonApp, App):
     """
@@ -30,18 +57,12 @@ class PersonApp(EmbedChainPersonApp, App):
     """
 
     def query(self, input_query, config: QueryConfig = None, dry_run=False):
-        self.template = Template(self.person_prompt + " " + DEFAULT_PROMPT)
-        query_config = QueryConfig(
-            template=self.template,
-        )
-        return super().query(input_query, query_config, dry_run)
+        config = self.add_person_template_to_config(DEFAULT_PROMPT, config)
+        return super().query(input_query, config, dry_run)
 
     def chat(self, input_query, config: ChatConfig = None, dry_run=False):
-        self.template = Template(self.person_prompt + " " + DEFAULT_PROMPT_WITH_HISTORY)
-        chat_config = ChatConfig(
-            template=self.template,
-        )
-        return super().chat(input_query, chat_config, dry_run)
+        config = self.add_person_template_to_config(DEFAULT_PROMPT_WITH_HISTORY, config)
+        return super().chat(input_query, config, dry_run)
 
 
 class PersonOpenSourceApp(EmbedChainPersonApp, OpenSourceApp):
@@ -51,15 +72,9 @@ class PersonOpenSourceApp(EmbedChainPersonApp, OpenSourceApp):
     """
 
     def query(self, input_query, config: QueryConfig = None, dry_run=False):
-        self.template = Template(self.person_prompt + " " + DEFAULT_PROMPT)
-        query_config = QueryConfig(
-            template=self.template,
-        )
-        return super().query(input_query, query_config, dry_run)
+        config = self.add_person_template_to_config(DEFAULT_PROMPT, config)
+        return super().query(input_query, config, dry_run)
 
     def chat(self, input_query, config: ChatConfig = None, dry_run=False):
-        self.template = Template(self.person_prompt + " " + DEFAULT_PROMPT_WITH_HISTORY)
-        chat_config = ChatConfig(
-            template=self.template,
-        )
-        return super().chat(input_query, chat_config, dry_run)
+        config = self.add_person_template_to_config(DEFAULT_PROMPT_WITH_HISTORY, config)
+        return super().chat(input_query, config, dry_run)