Browse Source

feat: system prompt (#448)

cachho 2 years ago
parent
commit
849de5e8ab

+ 1 - 1
docs/advanced/configuration.mdx

@@ -68,7 +68,7 @@ einstein_chat_template = Template("""
 
         Human: $query
         Albert Einstein:""")
-query_config = QueryConfig(template=einstein_chat_template)
+query_config = QueryConfig(template=einstein_chat_template, system_prompt="You are Albert Einstein.")
 queries = [
         "Where did you complete your studies?",
         "Why did you win nobel prize?",

+ 2 - 0
docs/advanced/query_configuration.mdx

@@ -65,6 +65,8 @@ _coming soon_
 |top_p|Controls the diversity of words. Higher values (closer to 1) make word selection more diverse, lower values make words less diverse.|float|1|
 |history|include conversation history from your client or database.|any (recommendation: list[str])|None|
 |stream|control if response is streamed back to the user.|bool|False|
+|deployment_name|t.b.a.|str|None|
+|system_prompt|System prompt string. Unused if none.|str|None|
 
 ## ChatConfig
 

+ 2 - 0
embedchain/apps/App.py

@@ -25,6 +25,8 @@ class App(EmbedChain):
 
     def get_llm_model_answer(self, prompt, config: ChatConfig):
         messages = []
+        if config.system_prompt:
+            messages.append({"role": "system", "content": config.system_prompt})
         messages.append({"role": "user", "content": prompt})
         response = openai.ChatCompletion.create(
             model=config.model or "gpt-3.5-turbo-0613",

+ 11 - 7
embedchain/apps/CustomApp.py

@@ -1,5 +1,5 @@
 import logging
-from typing import List
+from typing import List, Optional
 
 from langchain.schema import BaseMessage
 
@@ -84,7 +84,7 @@ class CustomApp(EmbedChain):
         if config.top_p and config.top_p != 1:
             logging.warning("Config option `top_p` is not supported by this model.")
 
-        messages = CustomApp._get_messages(prompt)
+        messages = CustomApp._get_messages(prompt, system_prompt=config.system_prompt)
 
         return chat(messages).content
 
@@ -97,7 +97,7 @@ class CustomApp(EmbedChain):
         if config.max_tokens and config.max_tokens != 1000:
             logging.warning("Config option `max_tokens` is not supported by this model.")
 
-        messages = CustomApp._get_messages(prompt)
+        messages = CustomApp._get_messages(prompt, system_prompt=config.system_prompt)
 
         return chat(messages).content
 
@@ -110,7 +110,7 @@ class CustomApp(EmbedChain):
         if config.top_p and config.top_p != 1:
             logging.warning("Config option `top_p` is not supported by this model.")
 
-        messages = CustomApp._get_messages(prompt)
+        messages = CustomApp._get_messages(prompt, system_prompt=config.system_prompt)
 
         return chat(messages).content
 
@@ -133,15 +133,19 @@ class CustomApp(EmbedChain):
         if config.top_p and config.top_p != 1:
             logging.warning("Config option `top_p` is not supported by this model.")
 
-        messages = CustomApp._get_messages(prompt)
+        messages = CustomApp._get_messages(prompt, system_prompt=config.system_prompt)
 
         return chat(messages).content
 
     @staticmethod
-    def _get_messages(prompt: str) -> List[BaseMessage]:
+    def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[BaseMessage]:
         from langchain.schema import HumanMessage, SystemMessage
 
-        return [SystemMessage(content="You are a helpful assistant."), HumanMessage(content=prompt)]
+        messages = []
+        if system_prompt:
+            messages.append(SystemMessage(content=system_prompt))
+        messages.append(HumanMessage(content=prompt))
+        return messages
 
     def _stream_llm_model_response(self, response):
         """

+ 4 - 2
embedchain/apps/Llama2App.py

@@ -2,7 +2,7 @@ import os
 
 from langchain.llms import Replicate
 
-from embedchain.config import AppConfig
+from embedchain.config import AppConfig, ChatConfig
 from embedchain.embedchain import EmbedChain
 
 
@@ -27,8 +27,10 @@ class Llama2App(EmbedChain):
 
         super().__init__(config)
 
-    def get_llm_model_answer(self, prompt, config: AppConfig = None):
+    def get_llm_model_answer(self, prompt, config: ChatConfig = None):
         # TODO: Move the model and other inputs into config
+        if config.system_prompt:
+            raise ValueError("Llama2App does not support `system_prompt`")
         llm = Replicate(
             model="a16z-infra/llama13b-v2-chat:df7690f1994d94e96ad9d568eac121aecf50684a0b0963b25a41cc40061269e5",
             input={"temperature": 0.75, "max_length": 500, "top_p": 1},

+ 3 - 0
embedchain/apps/OpenSourceApp.py

@@ -55,6 +55,9 @@ class OpenSourceApp(EmbedChain):
                 "OpenSourceApp does not support switching models at runtime. Please create a new app instance."
             )
 
+        if config.system_prompt:
+            raise ValueError("OpenSourceApp does not support `system_prompt`")
+
         response = self.instance.generate(
             prompt=prompt,
             streaming=config.stream,

+ 5 - 0
embedchain/config/ChatConfig.py

@@ -1,4 +1,5 @@
 from string import Template
+from typing import Optional
 
 from embedchain.config.QueryConfig import QueryConfig
 
@@ -34,6 +35,7 @@ class ChatConfig(QueryConfig):
         top_p=None,
         stream: bool = False,
         deployment_name=None,
+        system_prompt: Optional[str] = None,
     ):
         """
         Initializes the ChatConfig instance.
@@ -51,6 +53,8 @@ class ChatConfig(QueryConfig):
         (closer to 1) make word selection more diverse, lower values make words less
         diverse.
         :param stream: Optional. Control if response is streamed back to the user
+        :param deployment_name: t.b.a.
+        :param system_prompt: Optional. System prompt string.
         :raises ValueError: If the template is not valid as template should contain
         $context and $query and $history
         """
@@ -70,6 +74,7 @@ class ChatConfig(QueryConfig):
             history=[0],
             stream=stream,
             deployment_name=deployment_name,
+            system_prompt=system_prompt,
         )
 
     def set_history(self, history):

+ 5 - 0
embedchain/config/QueryConfig.py

@@ -1,5 +1,6 @@
 import re
 from string import Template
+from typing import Optional
 
 from embedchain.config.BaseConfig import BaseConfig
 
@@ -63,6 +64,7 @@ class QueryConfig(BaseConfig):
         history=None,
         stream: bool = False,
         deployment_name=None,
+        system_prompt: Optional[str] = None,
     ):
         """
         Initializes the QueryConfig instance.
@@ -81,6 +83,8 @@ class QueryConfig(BaseConfig):
         diverse.
         :param history: Optional. A list of strings to consider as history.
         :param stream: Optional. Control if response is streamed back to user
+        :param deployment_name: t.b.a.
+        :param system_prompt: Optional. System prompt string.
         :raises ValueError: If the template is not valid as template should
         contain $context and $query (and optionally $history).
         """
@@ -108,6 +112,7 @@ class QueryConfig(BaseConfig):
         self.model = model
         self.top_p = top_p if top_p else 1
         self.deployment_name = deployment_name
+        self.system_prompt = system_prompt
 
         if self.validate_template(template):
             self.template = template

+ 17 - 0
tests/embedchain/test_query.py

@@ -41,3 +41,20 @@ class TestApp(unittest.TestCase):
         self.assertEqual(mock_retrieve.call_args[0][0], "Test query")
         self.assertIsInstance(mock_retrieve.call_args[0][1], QueryConfig)
         mock_answer.assert_called_once()
+
+    @patch("openai.ChatCompletion.create")
+    def test_query_config_passing(self, mock_create):
+        mock_create.return_value = {"choices": [{"message": {"content": "response"}}]}  # Mock response
+
+        config = AppConfig()
+        chat_config = QueryConfig(system_prompt="Test system prompt")
+        app = App(config=config)
+
+        app.get_llm_model_answer("Test query", chat_config)
+
+        # Test systemp_prompt: Check that the 'create' method was called with the correct 'messages' argument
+        messages_arg = mock_create.call_args.kwargs["messages"]
+        self.assertEqual(messages_arg[0]["role"], "system")
+        self.assertEqual(messages_arg[0]["content"], "Test system prompt")
+
+        # TODO: Add tests for other config variables