
feat: add streaming support for OpenAI models (#202)

aaishikdutta 2 years ago
parent
commit
66c4d30c60
3 changed files with 58 additions and 6 deletions
  1. README.md (+13, -0)
  2. embedchain/config/InitConfig.py (+5, -1)
  3. embedchain/embedchain.py (+40, -5)

+ 13 - 0
README.md

@@ -204,6 +204,19 @@ from embedchain import PersonApp as ECPApp
 print(naval_chat_bot.query("What unique capacity does Naval argue humans possess when it comes to understanding explanations or concepts?"))
 # answer: Naval argues that humans possess the unique capacity to understand explanations or concepts to the maximum extent possible in this physical reality.
 ```
+### Stream Response
+
+- You can configure the app to stream responses the way ChatGPT does. You will need a downstream handler to render the chunks in your desired format.
+
+- To use this, instantiate `App` with an `InitConfig` instance, passing `stream_response=True`. The following example iterates through the chunks and prints them as they arrive.
+```python
+naval_chat_bot = App(InitConfig(stream_response=True))
+resp = naval_chat_bot.query("What unique capacity does Naval argue humans possess when it comes to understanding explanations or concepts?")
+
+for chunk in resp:
+    print(chunk, end="", flush=True)
+# answer: Naval argues that humans possess the unique capacity to understand explanations or concepts to the maximum extent possible in this physical reality.
+```
 
 ### Chat Interface
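For context on the "downstream handler" mentioned in the README addition above, a minimal sketch could look like the following; `render_stream` and `on_chunk` are illustrative names, not part of this commit.

```python
# Hypothetical helper, not part of this commit: forward each streamed chunk to a
# callback (stdout, a websocket, a UI widget, ...) while collecting the full answer.
def render_stream(chunks, on_chunk=lambda c: print(c, end="", flush=True)):
    full_answer = ""
    for chunk in chunks:
        on_chunk(chunk)
        full_answer += chunk
    return full_answer

# Usage with the README example above:
# answer = render_stream(naval_chat_bot.query("..."))
```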
 

+ 5 - 1
embedchain/config/InitConfig.py

@@ -6,7 +6,7 @@ class InitConfig(BaseConfig):
     """
     Config to initialize an embedchain `App` instance.
     """
-    def __init__(self, ef=None, db=None):
+    def __init__(self, ef=None, db=None, stream_response=False):
         """
         :param ef: Optional. Embedding function to use.
         :param db: Optional. (Vector) database to use for embeddings.
@@ -27,6 +27,10 @@ class InitConfig(BaseConfig):
             self.db = ChromaDB(ef=self.ef)
         else:
             self.db = db
+
+        if not isinstance(stream_response, bool):
+            raise ValueError("`stream_response` should be a boolean")
+        self.stream_response = stream_response
 
         return
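As a usage note for the new flag, the check added above means streaming has to be requested with an actual boolean; anything else fails at construction time. A small sketch, assuming `InitConfig` is imported as in the README example:

```python
InitConfig(stream_response=True)    # OK: query() will return a generator of chunks
InitConfig(stream_response="yes")   # raises ValueError: `stream_response` should be a boolean
```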
 

+ 40 - 5
embedchain/embedchain.py

@@ -164,8 +164,8 @@ class EmbedChain:
         :param context: Similar documents to the query used as context.
         :return: The answer.
         """
-        answer = self.get_llm_model_answer(prompt)
-        return answer
+        
+        return self.get_llm_model_answer(prompt)
 
     def query(self, input_query, config: QueryConfig = None):
         """
@@ -226,8 +226,20 @@ class EmbedChain:
         )
         answer = self.get_answer_from_llm(prompt)
         memory.chat_memory.add_user_message(input_query)
-        memory.chat_memory.add_ai_message(answer)
-        return answer
+        if isinstance(answer, str):
+            memory.chat_memory.add_ai_message(answer)
+            return answer
+        else:
+            # This is a streamed response and needs to be handled differently
+            return self._stream_chat_response(answer)
+
+    def _stream_chat_response(self, answer):
+        streamed_answer = ""
+        for chunk in answer:
+            streamed_answer += chunk
+            yield chunk
+        memory.chat_memory.add_ai_message(streamed_answer)
+
 
     def dry_run(self, input_query, config: QueryConfig = None):
         """
@@ -284,6 +296,13 @@ class App(EmbedChain):
         super().__init__(config)
 
     def get_llm_model_answer(self, prompt):
+        stream_response = self.config.stream_response
+        if stream_response:
+            return self._stream_llm_model_response(prompt)
+        else:
+            return self._get_llm_model_response(prompt)
+
+    def _get_llm_model_response(self, prompt, stream_response=False):
         messages = []
         messages.append({
             "role": "user", "content": prompt
@@ -294,8 +313,24 @@ class App(EmbedChain):
             temperature=0,
             max_tokens=1000,
             top_p=1,
+            stream=stream_response
         )
-        return response["choices"][0]["message"]["content"]
+
+        if stream_response:
+            # With stream=True this is a generator of completion chunks; the caller extracts the content deltas
+            return response
+        else:
+            return response["choices"][0]["message"]["content"]
+    
+    def _stream_llm_model_response(self, prompt):
+        """
+        Generator that yields the content chunks of a streamed OpenAI ChatCompletion response.
+        """
+        response = self._get_llm_model_response(prompt, True)
+        for line in response:
+            chunk = line['choices'][0].get('delta', {}).get('content', '')
+            yield chunk
+
 
 
 class OpenSourceApp(EmbedChain):
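
To illustrate the pattern behind `_stream_chat_response` and `_stream_llm_model_response`, here is a self-contained sketch; `fake_openai_stream` and `save_to_memory` are stand-ins for the OpenAI API and the chat memory, not part of the library. Each chunk is yielded to the caller as soon as it arrives, accumulated into the full answer with `+=`, and the complete text is persisted only after the caller has exhausted the generator.

```python
def fake_openai_stream():
    # Stand-in for the chunk objects the OpenAI API yields when stream=True.
    for piece in ["Naval argues ", "that humans ", "possess ..."]:
        yield {"choices": [{"delta": {"content": piece}}]}

def save_to_memory(text):
    # Stand-in for memory.chat_memory.add_ai_message(...).
    print(f"\n[saved]: {text}")

def stream_answer(response):
    streamed_answer = ""
    for line in response:
        chunk = line["choices"][0].get("delta", {}).get("content", "")
        streamed_answer += chunk     # accumulate; str.join() would discard the result
        yield chunk                  # hand the chunk to the caller immediately
    save_to_memory(streamed_answer)  # runs only once the caller finishes iterating

for chunk in stream_answer(fake_openai_stream()):
    print(chunk, end="", flush=True)
```

Persisting to memory inside the generator keeps the caller's loop simple, but it also means the AI message is only recorded if the consumer iterates the stream to completion.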