Prechádzať zdrojové kódy

bug: Fix/stream logging (#262)

aaishikdutta 2 rokov pred
rodič
commit
4335fff153
1 zmenil súbory, kde vykonal 22 pridanie a 6 odobranie
  1. 22 6
      embedchain/embedchain.py

+ 22 - 6
embedchain/embedchain.py

@@ -53,7 +53,9 @@ class EmbedChain:
 
         data_formatter = DataFormatter(data_type, config)
         self.user_asks.append([data_type, url, metadata])
-        self.load_and_embed(data_formatter.loader, data_formatter.chunker, url, metadata)
+        self.load_and_embed(
+            data_formatter.loader, data_formatter.chunker, url, metadata
+        )
 
     def add_local(self, data_type, content, metadata=None, config: AddConfig = None):
         """
@@ -117,10 +119,12 @@ class EmbedChain:
 
         chunks_before_addition = self.count()
 
-         # Add metadata to each document
+        # Add metadata to each document
         metadatas_with_metadata = [meta or metadata for meta in metadatas]
 
-        self.collection.add(documents=documents, metadatas=list(metadatas_with_metadata), ids=ids)
+        self.collection.add(
+            documents=documents, metadatas=list(metadatas_with_metadata), ids=ids
+        )
         print(
             (
                 f"Successfully saved {src}. New chunks count: "
@@ -210,9 +214,21 @@ class EmbedChain:
         contexts = self.retrieve_from_database(input_query, config)
         prompt = self.generate_prompt(input_query, contexts, config)
         logging.info(f"Prompt: {prompt}")
+
         answer = self.get_answer_from_llm(prompt, config)
-        logging.info(f"Answer: {answer}")
-        return answer
+
+        if isinstance(answer, str):
+            logging.info(f"Answer: {answer}")
+            return answer
+        else:
+            return self._stream_query_response(answer)
+
+    def _stream_query_response(self, answer):
+        streamed_answer = ""
+        for chunk in answer:
+            streamed_answer = streamed_answer + chunk
+            yield chunk
+        logging.info(f"Answer: {streamed_answer}")
 
     def chat(self, input_query, config: ChatConfig = None):
         """
@@ -254,7 +270,7 @@ class EmbedChain:
     def _stream_chat_response(self, answer):
         streamed_answer = ""
         for chunk in answer:
-            streamed_answer.join(chunk)
+            streamed_answer = streamed_answer + chunk
             yield chunk
         memory.chat_memory.add_ai_message(streamed_answer)
         logging.info(f"Answer: {streamed_answer}")