|
@@ -1,9 +1,9 @@
|
|
import re
|
|
import re
|
|
from string import Template
|
|
from string import Template
|
|
-from typing import Any, Dict, Optional
|
|
|
|
|
|
+from typing import Any, Dict, List, Optional
|
|
|
|
|
|
from embedchain.config.base_config import BaseConfig
|
|
from embedchain.config.base_config import BaseConfig
|
|
-from embedchain.helper.json_serializable import register_deserializable
|
|
|
|
|
|
+from embedchain.helpers.json_serializable import register_deserializable
|
|
|
|
|
|
DEFAULT_PROMPT = """
|
|
DEFAULT_PROMPT = """
|
|
Use the following pieces of context to answer the query at the end.
|
|
Use the following pieces of context to answer the query at the end.
|
|
@@ -68,6 +68,7 @@ class BaseLlmConfig(BaseConfig):
|
|
system_prompt: Optional[str] = None,
|
|
system_prompt: Optional[str] = None,
|
|
where: Dict[str, Any] = None,
|
|
where: Dict[str, Any] = None,
|
|
query_type: Optional[str] = None,
|
|
query_type: Optional[str] = None,
|
|
|
|
+ callbacks: Optional[List] = None,
|
|
):
|
|
):
|
|
"""
|
|
"""
|
|
Initializes a configuration class instance for the LLM.
|
|
Initializes a configuration class instance for the LLM.
|
|
@@ -98,6 +99,8 @@ class BaseLlmConfig(BaseConfig):
|
|
:type system_prompt: Optional[str], optional
|
|
:type system_prompt: Optional[str], optional
|
|
:param where: A dictionary of key-value pairs to filter the database results., defaults to None
|
|
:param where: A dictionary of key-value pairs to filter the database results., defaults to None
|
|
:type where: Dict[str, Any], optional
|
|
:type where: Dict[str, Any], optional
|
|
|
|
+ :param callbacks: Langchain callback functions to use, defaults to None
|
|
|
|
+ :type callbacks: Optional[List], optional
|
|
:raises ValueError: If the template is not valid as template should
|
|
:raises ValueError: If the template is not valid as template should
|
|
contain $context and $query (and optionally $history)
|
|
contain $context and $query (and optionally $history)
|
|
:raises ValueError: Stream is not boolean
|
|
:raises ValueError: Stream is not boolean
|
|
@@ -113,6 +116,7 @@ class BaseLlmConfig(BaseConfig):
|
|
self.deployment_name = deployment_name
|
|
self.deployment_name = deployment_name
|
|
self.system_prompt = system_prompt
|
|
self.system_prompt = system_prompt
|
|
self.query_type = query_type
|
|
self.query_type = query_type
|
|
|
|
+ self.callbacks = callbacks
|
|
|
|
|
|
if type(template) is str:
|
|
if type(template) is str:
|
|
template = Template(template)
|
|
template = Template(template)
|