
fix: use template from temporary `LlmConfig` (#590)

cachho committed 1 year ago
parent commit 0f9a10c598
1 file changed, 66 insertions(+), 46 deletions(-)

embedchain/llm/base.py | +66 -46
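
The diff below replaces the old local `query_config = config or self.config` with a swap of `self.config` itself, saved before the call and restored in a `finally` block, so helpers such as `generate_prompt` that read `self.config` internally actually see the per-call template. A minimal sketch of that save/swap/restore pattern, outside embedchain (`Service`, `run`, and `_execute` are illustrative names, and `copy.deepcopy` stands in for the serializer round-trip the commit uses):

```python
import copy


class Service:
    def __init__(self, config):
        self.config = config

    def run(self, query, config=None):
        """Apply `config` for this call only, then restore the old one."""
        try:
            if config:
                # Snapshot the permanent config before swapping it out.
                prev_config = copy.deepcopy(self.config)
                self.config = config
            # Helpers read self.config internally, so they see the
            # temporary config for the duration of this call.
            return self._execute(query)
        finally:
            if config:
                # Runs even if _execute raised or returned early.
                self.config = prev_config

    def _execute(self, query):
        return f"{self.config['template']}: {query}"


svc = Service({"template": "default"})
print(svc.run("hi", config={"template": "one-off"}))  # one-off: hi
print(svc.run("hi"))                                  # default: hi
```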

@@ -174,27 +174,37 @@ class BaseLlm(JSONSerializable):
         :return: The answer to the query or the dry run result
         :rtype: str
         """
-        query_config = config or self.config
-
-        if self.is_docs_site_instance:
-            query_config.template = DOCS_SITE_PROMPT_TEMPLATE
-            query_config.number_documents = 5
-        k = {}
-        if self.online:
-            k["web_search_result"] = self.access_search_and_get_results(input_query)
-        prompt = self.generate_prompt(input_query, contexts, **k)
-        logging.info(f"Prompt: {prompt}")
-
-        if dry_run:
-            return prompt
-
-        answer = self.get_answer_from_llm(prompt)
-
-        if isinstance(answer, str):
-            logging.info(f"Answer: {answer}")
-            return answer
-        else:
-            return self._stream_query_response(answer)
+        try:
+            if config:
+                # A config instance passed to this method will only be applied temporarily, for one call.
+                # So we will save the previous config and restore it at the end of the execution.
+                # For this we use the serializer.
+                prev_config = self.config.serialize()
+                self.config = config
+
+            if self.is_docs_site_instance:
+                self.config.template = DOCS_SITE_PROMPT_TEMPLATE
+                self.config.number_documents = 5
+            k = {}
+            if self.online:
+                k["web_search_result"] = self.access_search_and_get_results(input_query)
+            prompt = self.generate_prompt(input_query, contexts, **k)
+            logging.info(f"Prompt: {prompt}")
+
+            if dry_run:
+                return prompt
+
+            answer = self.get_answer_from_llm(prompt)
+
+            if isinstance(answer, str):
+                logging.info(f"Answer: {answer}")
+                return answer
+            else:
+                return self._stream_query_response(answer)
+        finally:
+            if config:
+                # Restore previous config
+                self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config)
 
     def chat(self, input_query: str, contexts: List[str], config: BaseLlmConfig = None, dry_run=False):
         """
@@ -217,39 +227,49 @@ class BaseLlm(JSONSerializable):
         :return: The answer to the query or the dry run result
         :rtype: str
         """
-        query_config = config or self.config
+        try:
+            if config:
+                # A config instance passed to this method will only be applied temporarily, for one call.
+                # So we will save the previous config and restore it at the end of the execution.
+                # For this we use the serializer.
+                prev_config = self.config.serialize()
+                self.config = config
+
+            if self.is_docs_site_instance:
+                self.config.template = DOCS_SITE_PROMPT_TEMPLATE
+                self.config.number_documents = 5
+            k = {}
+            if self.online:
+                k["web_search_result"] = self.access_search_and_get_results(input_query)
 
-        if self.is_docs_site_instance:
-            query_config.template = DOCS_SITE_PROMPT_TEMPLATE
-            query_config.number_documents = 5
-        k = {}
-        if self.online:
-            k["web_search_result"] = self.access_search_and_get_results(input_query)
-
-        self.update_history()
+            self.update_history()
 
-        prompt = self.generate_prompt(input_query, contexts, **k)
-        logging.info(f"Prompt: {prompt}")
+            prompt = self.generate_prompt(input_query, contexts, **k)
+            logging.info(f"Prompt: {prompt}")
 
-        if dry_run:
-            return prompt
+            if dry_run:
+                return prompt
 
-        answer = self.get_answer_from_llm(prompt)
+            answer = self.get_answer_from_llm(prompt)
 
-        self.memory.chat_memory.add_user_message(input_query)
+            self.memory.chat_memory.add_user_message(input_query)
 
-        if isinstance(answer, str):
-            self.memory.chat_memory.add_ai_message(answer)
-            logging.info(f"Answer: {answer}")
+            if isinstance(answer, str):
+                self.memory.chat_memory.add_ai_message(answer)
+                logging.info(f"Answer: {answer}")
 
-            # NOTE: Adding to history before and after. This could be seen as redundant.
-            # If we change it, we have to change the tests (no big deal).
-            self.update_history()
+                # NOTE: Adding to history before and after. This could be seen as redundant.
+                # If we change it, we have to change the tests (no big deal).
+                self.update_history()
 
-            return answer
-        else:
-            # this is a streamed response and needs to be handled differently.
-            return self._stream_chat_response(answer)
+                return answer
+            else:
+                # this is a streamed response and needs to be handled differently.
+                return self._stream_chat_response(answer)
+        finally:
+            if config:
+                # Restore previous config
+                self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config)
 
     @staticmethod
     def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[BaseMessage]:
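
For callers, the visible effect of both hunks is that a config passed to `query` or `chat` no longer leaks into later calls. A hedged usage sketch (`MyLlm`, the template values, `docs`, and the exact `BaseLlmConfig` constructor arguments and import path are assumptions, not taken from this diff):

```python
from embedchain.config import BaseLlmConfig  # import path assumed

llm = MyLlm(config=BaseLlmConfig(template=DEFAULT_TEMPLATE))

# One-off config: applies to this single call only.
one_off = BaseLlmConfig(template=CUSTOM_TEMPLATE, number_documents=3)
answer = llm.query("What is Embedchain?", contexts=docs, config=one_off)

# The permanent config is restored afterwards.
assert llm.config.template == DEFAULT_TEMPLATE
```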