@@ -155,7 +155,7 @@ class EmbedChain:
         prompt = template.substitute(context = context, query = input_query)
         return prompt

-    def get_answer_from_llm(self, prompt):
+    def get_answer_from_llm(self, prompt, config: ChatConfig):
         """
         Gets an answer based on the given query and context by passing it
         to an LLM.
@@ -165,7 +165,7 @@ class EmbedChain:
         :return: The answer.
         """

-        return self.get_llm_model_answer(prompt)
+        return self.get_llm_model_answer(prompt, config)

     def query(self, input_query, config: QueryConfig = None):
         """
@@ -181,7 +181,7 @@ class EmbedChain:
             config = QueryConfig()
         context = self.retrieve_from_database(input_query)
         prompt = self.generate_prompt(input_query, context, config.template)
-        answer = self.get_answer_from_llm(prompt)
+        answer = self.get_answer_from_llm(prompt, config)
         return answer

     def generate_chat_prompt(self, input_query, context, chat_history=''):
@@ -224,7 +224,7 @@ class EmbedChain:
            context,
            chat_history=chat_history,
        )
-        answer = self.get_answer_from_llm(prompt)
+        answer = self.get_answer_from_llm(prompt, config)
         memory.chat_memory.add_user_message(input_query)
         if isinstance(answer, str):
             memory.chat_memory.add_ai_message(answer)
@@ -295,14 +295,8 @@ class App(EmbedChain):
             config = InitConfig()
         super().__init__(config)

-    def get_llm_model_answer(self, prompt):
-        stream_response = self.config.stream_response
-        if stream_response:
-            return self._stream_llm_model_response(prompt)
-        else:
-            return self._get_llm_model_response(prompt)
+    def get_llm_model_answer(self, prompt, config: ChatConfig):

-    def _get_llm_model_response(self, prompt, stream_response = False):
         messages = []
         messages.append({
             "role": "user", "content": prompt
@@ -313,20 +307,18 @@ class App(EmbedChain):
             temperature=0,
             max_tokens=1000,
             top_p=1,
-            stream=stream_response
+            stream=config.stream
         )

-        if stream_response:
-            # This contains the entire completions object. Needs to be sanitised
-            return response
+        if config.stream:
+            return self._stream_llm_model_response(response)
         else:
             return response["choices"][0]["message"]["content"]

-    def _stream_llm_model_response(self, prompt):
+    def _stream_llm_model_response(self, response):
         """
         This is a generator for streaming response from the OpenAI completions API
         """
-        response = self._get_llm_model_response(prompt, True)
         for line in response:
             chunk = line['choices'][0].get('delta', {}).get('content', '')
             yield chunk
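
For context, a minimal usage sketch of the streaming path introduced by this diff. The import paths, the `add()` call, and the exact `chat()` signature are assumptions based on the embedchain API of this era, not part of the patch itself:

    # Sketch only: assumes OPENAI_API_KEY is set and that ChatConfig is
    # exported from embedchain.config in this version of the library.
    from embedchain import App
    from embedchain.config import ChatConfig

    app = App()
    app.add("web_page", "https://example.com/article")  # hypothetical data source

    # With stream=True, get_llm_model_answer() returns the generator from
    # _stream_llm_model_response(), so chat() yields the answer chunk by chunk.
    for chunk in app.chat("Summarize the article.", ChatConfig(stream=True)):
        print(chunk, end="", flush=True)

    # With the default (stream=False), chat() returns the full answer string.
    print(app.chat("Give one key takeaway.", ChatConfig()))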