
System prompt at App level (#484)

Co-authored-by: Taranjeet Singh <reachtotj@gmail.com>
Dev Khant 1 year ago
Parent
Current commit
ec9f454ad1
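This change lets a system prompt be set once when the app is constructed, instead of only per call through ChatConfig/QueryConfig. A minimal usage sketch, assuming the public `from embedchain import App` import path and an OPENAI_API_KEY in the environment:

import os

from embedchain import App  # public import path assumed

os.environ["OPENAI_API_KEY"] = "sk-..."  # placeholder key for the sketch

# New in this commit: the system prompt is supplied at construction time
app = App(system_prompt="You are a helpful assistant that answers concisely.")

# It is then sent with every query, without needing a per-call ChatConfig
print(app.query("What is embedchain?"))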

embedchain/apps/App.py (+14, -4)

@@ -1,3 +1,5 @@
+from typing import Optional
+
 import openai
 
 from embedchain.config import AppConfig, ChatConfig
@@ -14,19 +16,27 @@ class App(EmbedChain):
     dry_run(query): test your prompt without consuming tokens.
     """
 
-    def __init__(self, config: AppConfig = None):
+    def __init__(self, config: AppConfig = None, system_prompt: Optional[str] = None):
         """
         :param config: AppConfig instance to load as configuration. Optional.
+        :param system_prompt: System prompt string. Optional.
         """
         if config is None:
             config = AppConfig()
 
-        super().__init__(config)
+        super().__init__(config, system_prompt)
 
     def get_llm_model_answer(self, prompt, config: ChatConfig):
         messages = []
-        if config.system_prompt:
-            messages.append({"role": "system", "content": config.system_prompt})
+        system_prompt = (
+            self.system_prompt
+            if self.system_prompt is not None
+            else config.system_prompt
+            if config.system_prompt is not None
+            else None
+        )
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
         messages.append({"role": "user", "content": prompt})
         response = openai.ChatCompletion.create(
             model=config.model or "gpt-3.5-turbo-0613",
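The selection above prefers the app-level prompt and only falls back to the per-call config. A standalone sketch of that precedence (the function name and assertions are illustrative, not part of the library):

from typing import Optional


def resolve_system_prompt(app_prompt: Optional[str], config_prompt: Optional[str]) -> Optional[str]:
    # Mirrors the fallback in App.get_llm_model_answer: app level wins, then config, else None
    return app_prompt if app_prompt is not None else config_prompt


assert resolve_system_prompt("app", "config") == "app"    # app-level overrides the per-call config
assert resolve_system_prompt(None, "config") == "config"  # falls back to ChatConfig.system_prompt
assert resolve_system_prompt(None, None) is None          # no system message is added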

embedchain/apps/CustomApp.py (+6, -2)

@@ -18,10 +18,11 @@ class CustomApp(EmbedChain):
     dry_run(query): test your prompt without consuming tokens.
     """
 
-    def __init__(self, config: CustomAppConfig = None):
+    def __init__(self, config: CustomAppConfig = None, system_prompt: Optional[str] = None):
         """
         :param config: Optional. `CustomAppConfig` instance to load as configuration.
         :raises ValueError: Config must be provided for custom app
+        :param system_prompt: Optional. System prompt string.
         """
         if config is None:
             raise ValueError("Config must be provided for custom app")
@@ -34,7 +35,7 @@ class CustomApp(EmbedChain):
             # Because these models run locally, they should have an instance running when the custom app is created
             self.open_source_app = OpenSourceApp(config=config.open_source_app_config)
 
-        super().__init__(config)
+        super().__init__(config, system_prompt)
 
     def set_llm_model(self, provider: Providers):
         self.provider = provider
@@ -51,6 +52,9 @@ class CustomApp(EmbedChain):
                 "Streaming responses have not been implemented for this model yet. Please disable."
             )
 
+        if config.system_prompt is None and self.system_prompt is not None:
+            config.system_prompt = self.system_prompt
+
         try:
             if self.provider == Providers.OPENAI:
                 return CustomApp._get_openai_answer(prompt, config)
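CustomApp takes a slightly different route: instead of resolving the prompt at call time, it copies the app-level prompt into the per-call config when that config has none, so the provider-specific helpers keep reading config.system_prompt. A small sketch of that copy-down rule (FakeChatConfig is a stand-in, not the real ChatConfig):

from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeChatConfig:
    """Stand-in for ChatConfig; only the field relevant here."""
    system_prompt: Optional[str] = None


def apply_app_prompt(config: FakeChatConfig, app_prompt: Optional[str]) -> FakeChatConfig:
    # Same guard as in CustomApp.get_llm_model_answer: only fill the gap, never overwrite
    if config.system_prompt is None and app_prompt is not None:
        config.system_prompt = app_prompt
    return config


cfg = apply_app_prompt(FakeChatConfig(), app_prompt="Be terse.")
assert cfg.system_prompt == "Be terse."

cfg = apply_app_prompt(FakeChatConfig(system_prompt="Per-call prompt"), app_prompt="Be terse.")
assert cfg.system_prompt == "Per-call prompt"  # an explicit per-call prompt is kept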

embedchain/apps/Llama2App.py (+5, -3)

@@ -1,4 +1,5 @@
 import os
+from typing import Optional
 
 from langchain.llms import Replicate
 
@@ -15,9 +16,10 @@ class Llama2App(EmbedChain):
     query(query): finds answer to the given query using vector database and LLM.
     """
 
-    def __init__(self, config: AppConfig = None):
+    def __init__(self, config: AppConfig = None, system_prompt: Optional[str] = None):
         """
         :param config: AppConfig instance to load as configuration. Optional.
+        :param system_prompt: System prompt string. Optional.
         """
         if "REPLICATE_API_TOKEN" not in os.environ:
             raise ValueError("Please set the REPLICATE_API_TOKEN environment variable.")
@@ -25,11 +27,11 @@ class Llama2App(EmbedChain):
         if config is None:
             config = AppConfig()
 
-        super().__init__(config)
+        super().__init__(config, system_prompt)
 
     def get_llm_model_answer(self, prompt, config: ChatConfig = None):
         # TODO: Move the model and other inputs into config
-        if config.system_prompt:
+        if self.system_prompt or config.system_prompt:
             raise ValueError("Llama2App does not support `system_prompt`")
         llm = Replicate(
             model="a16z-infra/llama13b-v2-chat:df7690f1994d94e96ad9d568eac121aecf50684a0b0963b25a41cc40061269e5",

embedchain/apps/OpenSourceApp.py (+5, -4)

@@ -1,5 +1,5 @@
 import logging
-from typing import Iterable, Union
+from typing import Iterable, Union, Optional
 
 from embedchain.config import ChatConfig, OpenSourceAppConfig
 from embedchain.embedchain import EmbedChain
@@ -18,10 +18,11 @@ class OpenSourceApp(EmbedChain):
     query(query): finds answer to the given query using vector database and LLM.
     """
 
-    def __init__(self, config: OpenSourceAppConfig = None):
+    def __init__(self, config: OpenSourceAppConfig = None, system_prompt: Optional[str] = None):
         """
         :param config: OpenSourceAppConfig instance to load as configuration. Optional.
         `ef` defaults to open source.
+        :param system_prompt: System prompt string. Optional.
         """
         logging.info("Loading open source embedding model. This may take some time...")  # noqa:E501
         if not config:
@@ -33,7 +34,7 @@ class OpenSourceApp(EmbedChain):
         self.instance = OpenSourceApp._get_instance(config.model)
 
         logging.info("Successfully loaded open source embedding model.")
-        super().__init__(config)
+        super().__init__(config, system_prompt)
 
     def get_llm_model_answer(self, prompt, config: ChatConfig):
         return self._get_gpt4all_answer(prompt=prompt, config=config)
@@ -55,7 +56,7 @@ class OpenSourceApp(EmbedChain):
                 "OpenSourceApp does not support switching models at runtime. Please create a new app instance."
             )
 
-        if config.system_prompt:
+        if self.system_prompt or config.system_prompt:
             raise ValueError("OpenSourceApp does not support `system_prompt`")
 
         response = self.instance.generate(
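OpenSourceApp gets the same guard, so the practical pattern is simply not to set a system prompt at either level when using it. A minimal construction sketch, assuming OpenSourceApp is exported from the package root (the GPT4All model is loaded inside the constructor):

from embedchain import OpenSourceApp  # import path assumed

# OpenSourceApp rejects `system_prompt` at both the app and the config level,
# so construct it without one and rely on plain prompts instead.
app = OpenSourceApp()
print(app.query("Summarise the indexed documents in one sentence."))

# Passing a prompt anyway now fails fast on the first query:
# OpenSourceApp(system_prompt="...").query("hi")  ->  ValueError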

embedchain/embedchain.py (+3, -1)

@@ -33,15 +33,17 @@ CONFIG_FILE = os.path.join(CONFIG_DIR, "config.json")
 
 
 class EmbedChain:
-    def __init__(self, config: BaseAppConfig):
+    def __init__(self, config: BaseAppConfig, system_prompt: Optional[str] = None):
         """
         Initializes the EmbedChain instance, sets up a vector DB client and
         creates a collection.
 
         :param config: BaseAppConfig instance to load as configuration.
+        :param system_prompt: Optional. System prompt string.
         """
 
         self.config = config
+        self.system_prompt = system_prompt
         self.collection = self.config.db._get_or_create_collection(self.config.collection_name)
         self.db = self.config.db
         self.user_asks = []
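Because the attribute now lives on the EmbedChain base class, any app subclass can read self.system_prompt inside its own get_llm_model_answer. A hypothetical subclass sketch; EchoApp is illustrative and the LLM call is faked, only the prompt plumbing mirrors the commit:

from typing import Optional

from embedchain.config import AppConfig, ChatConfig
from embedchain.embedchain import EmbedChain


class EchoApp(EmbedChain):
    """Toy subclass showing where the app-level system prompt becomes visible."""

    def __init__(self, config: AppConfig = None, system_prompt: Optional[str] = None):
        super().__init__(config or AppConfig(), system_prompt)

    def get_llm_model_answer(self, prompt, config: ChatConfig):
        # self.system_prompt is populated by the base-class __init__ added in this commit
        prefix = f"[{self.system_prompt}] " if self.system_prompt else ""
        return prefix + prompt


app = EchoApp(system_prompt="Answer like a librarian.")
print(app.get_llm_model_answer("Where are the archives?", ChatConfig()))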

tests/embedchain/test_query.py (+17, -2)

@@ -43,7 +43,7 @@ class TestApp(unittest.TestCase):
         mock_answer.assert_called_once()
 
     @patch("openai.ChatCompletion.create")
-    def test_query_config_passing(self, mock_create):
+    def test_query_config_app_passing(self, mock_create):
         mock_create.return_value = {"choices": [{"message": {"content": "response"}}]}  # Mock response
 
         config = AppConfig()
@@ -52,9 +52,24 @@ class TestApp(unittest.TestCase):
 
         app.get_llm_model_answer("Test query", chat_config)
 
-        # Test systemp_prompt: Check that the 'create' method was called with the correct 'messages' argument
+        # Test system_prompt: Check that the 'create' method was called with the correct 'messages' argument
         messages_arg = mock_create.call_args.kwargs["messages"]
         self.assertEqual(messages_arg[0]["role"], "system")
         self.assertEqual(messages_arg[0]["content"], "Test system prompt")
 
         # TODO: Add tests for other config variables
+
+    @patch("openai.ChatCompletion.create")
+    def test_app_passing(self, mock_create):
+        mock_create.return_value = {"choices": [{"message": {"content": "response"}}]}  # Mock response
+
+        config = AppConfig()
+        chat_config = QueryConfig()
+        app = App(config=config, system_prompt="Test system prompt")
+
+        app.get_llm_model_answer("Test query", chat_config)
+
+        # Test system_prompt: Check that the 'create' method was called with the correct 'messages' argument
+        messages_arg = mock_create.call_args.kwargs["messages"]
+        self.assertEqual(messages_arg[0]["role"], "system")
+        self.assertEqual(messages_arg[0]["content"], "Test system prompt")
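A natural follow-up test, not part of this commit, would assert that the app-level prompt wins when both levels are set, matching the fallback order added to App.get_llm_model_answer. A hedged sketch in the same style as the tests above, reusing the imports already present in tests/embedchain/test_query.py and assuming QueryConfig accepts system_prompt, as the existing test implies:

    @patch("openai.ChatCompletion.create")
    def test_app_prompt_overrides_config_prompt(self, mock_create):
        # Hypothetical additional test: when both prompts are set, the app-level one is sent
        mock_create.return_value = {"choices": [{"message": {"content": "response"}}]}

        config = AppConfig()
        chat_config = QueryConfig(system_prompt="Config prompt")
        app = App(config=config, system_prompt="App prompt")

        app.get_llm_model_answer("Test query", chat_config)

        messages_arg = mock_create.call_args.kwargs["messages"]
        self.assertEqual(messages_arg[0]["role"], "system")
        self.assertEqual(messages_arg[0]["content"], "App prompt")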