@@ -164,8 +164,8 @@ class EmbedChain:
         :param context: Similar documents to the query used as context.
         :return: The answer.
         """
-        answer = self.get_llm_model_answer(prompt)
-        return answer
+
+        return self.get_llm_model_answer(prompt)

     def query(self, input_query, config: QueryConfig = None):
         """
@@ -226,8 +226,20 @@ class EmbedChain:
         )
         answer = self.get_answer_from_llm(prompt)
         memory.chat_memory.add_user_message(input_query)
-        memory.chat_memory.add_ai_message(answer)
-        return answer
+        if isinstance(answer, str):
+            memory.chat_memory.add_ai_message(answer)
+            return answer
+        else:
+            # This is a streamed response and needs to be handled differently.
+            return self._stream_chat_response(answer)
+
+    def _stream_chat_response(self, answer):
+        streamed_answer = ""
+        for chunk in answer:
+            streamed_answer += chunk
+            yield chunk
+        memory.chat_memory.add_ai_message(streamed_answer)
+

     def dry_run(self, input_query, config: QueryConfig = None):
         """
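The `_stream_chat_response` helper added above forwards each chunk to the caller while also accumulating the full text, so the chat memory is only written once the caller has exhausted the generator. A minimal, self-contained sketch of that accumulate-while-yielding pattern (the `save_to_history` callback and the sample chunks are illustrative, not part of the codebase):

    def stream_and_record(chunks, save_to_history):
        """Yield each chunk to the caller while building up the full answer."""
        full_answer = ""
        for chunk in chunks:
            full_answer += chunk
            yield chunk
        # Runs only after the caller has consumed every chunk.
        save_to_history(full_answer)

    history = []
    for piece in stream_and_record(["Hello", ", ", "world"], history.append):
        print(piece, end="")
    print()
    print(history)  # ['Hello, world']
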
@@ -284,6 +296,13 @@ class App(EmbedChain):
         super().__init__(config)

     def get_llm_model_answer(self, prompt):
+        stream_response = self.config.stream_response
+        if stream_response:
+            return self._stream_llm_model_response(prompt)
+        else:
+            return self._get_llm_model_response(prompt)
+
+    def _get_llm_model_response(self, prompt, stream_response=False):
         messages = []
         messages.append({
             "role": "user", "content": prompt
@@ -294,8 +313,24 @@ class App(EmbedChain):
             temperature=0,
             max_tokens=1000,
             top_p=1,
+            stream=stream_response
         )
-        return response["choices"][0]["message"]["content"]
+
+        if stream_response:
+            # This contains the entire completions object; it needs to be sanitised.
+            return response
+        else:
+            return response["choices"][0]["message"]["content"]
+
+    def _stream_llm_model_response(self, prompt):
+        """
+        This is a generator for streaming responses from the OpenAI completions API.
+        """
+        response = self._get_llm_model_response(prompt, True)
+        for line in response:
+            chunk = line['choices'][0].get('delta', {}).get('content', '')
+            yield chunk
+


 class OpenSourceApp(EmbedChain):
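With these changes, enabling streaming on the app's config (the `stream_response` flag read in `get_llm_model_answer` above) makes `chat` return a generator of text chunks instead of a single string. A hedged usage sketch of how a caller might consume either shape; `app` is assumed to be an already-initialised `App` with documents ingested, and the question text is illustrative:

    response = app.chat("Summarise the documents that were added.")
    if isinstance(response, str):
        # Streaming disabled: the complete answer arrives as one string.
        print(response)
    else:
        # Streaming enabled: iterate the generator; the chat history is updated
        # by _stream_chat_response only after the last chunk has been yielded.
        for chunk in response:
            print(chunk, end="", flush=True)
        print()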