|
@@ -9,7 +9,7 @@ from langchain.docstore.document import Document
|
|
|
from langchain.memory import ConversationBufferMemory
|
|
|
|
|
|
from embedchain.config import AddConfig, ChatConfig, InitConfig, QueryConfig
|
|
|
-from embedchain.config.QueryConfig import DEFAULT_PROMPT
|
|
|
+from embedchain.config.QueryConfig import DEFAULT_PROMPT, CODE_DOCS_PAGE_PROMPT_TEMPLATE
|
|
|
from embedchain.data_formatter import DataFormatter
|
|
|
|
|
|
gpt4all_model = None
|
|
@@ -35,6 +35,7 @@ class EmbedChain:
|
|
|
self.db_client = self.config.db.client
|
|
|
self.collection = self.config.db.collection
|
|
|
self.user_asks = []
|
|
|
+ self.is_code_docs_instance = False
|
|
|
|
|
|
def add(self, data_type, url, metadata=None, config: AddConfig = None):
|
|
|
"""
|
|
@@ -56,6 +57,8 @@ class EmbedChain:
|
|
|
self.load_and_embed(
|
|
|
data_formatter.loader, data_formatter.chunker, url, metadata
|
|
|
)
|
|
|
+ if data_type in ("code_docs_page", ):
|
|
|
+ self.is_code_docs_instance = True
|
|
|
|
|
|
def add_local(self, data_type, content, metadata=None, config: AddConfig = None):
|
|
|
"""
|
|
@@ -211,6 +214,9 @@ class EmbedChain:
|
|
|
"""
|
|
|
if config is None:
|
|
|
config = QueryConfig()
|
|
|
+ if self.is_code_docs_instance:
|
|
|
+ config.template = CODE_DOCS_PAGE_PROMPT_TEMPLATE
|
|
|
+ config.number_documents = 5
|
|
|
contexts = self.retrieve_from_database(input_query, config)
|
|
|
prompt = self.generate_prompt(input_query, contexts, config)
|
|
|
logging.info(f"Prompt: {prompt}")
|
|
@@ -244,7 +250,9 @@ class EmbedChain:
|
|
|
"""
|
|
|
if config is None:
|
|
|
config = ChatConfig()
|
|
|
-
|
|
|
+ if self.is_code_docs_instance:
|
|
|
+ config.template = CODE_DOCS_PAGE_PROMPT_TEMPLATE
|
|
|
+ config.number_documents = 5
|
|
|
contexts = self.retrieve_from_database(input_query, config)
|
|
|
|
|
|
global memory
|