@@ -1,5 +1,5 @@
 import logging
-from typing import Iterable, Union
+from typing import Iterable, Union, Optional
 
 from embedchain.config import ChatConfig, OpenSourceAppConfig
 from embedchain.embedchain import EmbedChain
@@ -18,10 +18,11 @@ class OpenSourceApp(EmbedChain):
     query(query): finds answer to the given query using vector database and LLM.
     """
 
-    def __init__(self, config: OpenSourceAppConfig = None):
+    def __init__(self, config: OpenSourceAppConfig = None, system_prompt: Optional[str] = None):
         """
         :param config: OpenSourceAppConfig instance to load as configuration. Optional.
         `ef` defaults to open source.
+        :param system_prompt: System prompt string. Optional.
         """
         logging.info("Loading open source embedding model. This may take some time...")  # noqa:E501
         if not config:
@@ -33,7 +34,7 @@ class OpenSourceApp(EmbedChain):
         self.instance = OpenSourceApp._get_instance(config.model)
 
         logging.info("Successfully loaded open source embedding model.")
-        super().__init__(config)
+        super().__init__(config, system_prompt)
 
     def get_llm_model_answer(self, prompt, config: ChatConfig):
         return self._get_gpt4all_answer(prompt=prompt, config=config)
@@ -55,7 +56,7 @@ class OpenSourceApp(EmbedChain):
                 "OpenSourceApp does not support switching models at runtime. Please create a new app instance."
             )
 
-        if config.system_prompt:
+        if self.system_prompt or config.system_prompt:
             raise ValueError("OpenSourceApp does not support `system_prompt`")
 
         response = self.instance.generate(
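
Net effect of the patch, as a minimal usage sketch (assumes `from embedchain import OpenSourceApp` is the public import path at this revision, and that `EmbedChain.__init__` stores the forwarded value as `self.system_prompt`, as the last hunk implies; not part of the diff):

    from embedchain import OpenSourceApp

    # The constructor now accepts `system_prompt` and forwards it to the
    # EmbedChain base class instead of dropping it.
    app = OpenSourceApp(system_prompt="Answer in one sentence.")

    # The GPT4All answer path still rejects system prompts: because
    # `self.system_prompt` is set above, querying raises ValueError
    # before any text is generated.
    try:
        app.query("What is Embedchain?")
    except ValueError as err:
        print(err)  # OpenSourceApp does not support `system_prompt`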