
feat: add logging (#206)

cachho 2 years ago
parent
commit b3cf834186
3 changed files with 28 additions and 3 deletions
  1. README.md (+1 -0)
  2. embedchain/config/InitConfig.py (+18 -2)
  3. embedchain/embedchain.py (+9 -1)

+ 1 - 0
README.md

@@ -444,6 +444,7 @@ This section describes all possible config options.
 
 |option|description|type|default|
 |---|---|---|---|
+|log_level|log level: 'DEBUG', 'INFO', 'WARNING', 'ERROR' or 'CRITICAL'|string|WARNING|
 |ef|embedding function|chromadb.utils.embedding_functions|{text-embedding-ada-002}|
 |db|vector database (experimental)|BaseVectorDB|ChromaDB|
 

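For reference, a minimal usage sketch of the new option. The `App(config)` signature and the `embedchain.config` import path are assumptions based on this repository's layout, not shown in this diff:

```python
# Minimal sketch: turn on INFO-level logging via the new log_level option.
# InitConfig import path and App(config) signature are assumptions here.
from embedchain import App
from embedchain.config import InitConfig

config = InitConfig(log_level="INFO")  # default remains WARNING
app = App(config)
```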
+ 18 - 2
embedchain/config/InitConfig.py

@@ -1,4 +1,5 @@
 import os
+import logging
 
 from embedchain.config.BaseConfig import BaseConfig
 
@@ -6,11 +7,15 @@ class InitConfig(BaseConfig):
     """
     Config to initialize an embedchain `App` instance.
     """
-    def __init__(self, ef=None, db=None):
+
+    def __init__(self, log_level=None, ef=None, db=None):
         """
+        :param log_level: Optional. (String) Log level ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
         :param ef: Optional. Embedding function to use.
         :param db: Optional. (Vector) database to use for embeddings.
         """
+        self._setup_logging(log_level)
+
         # Embedding Function
         if ef is None:
             from chromadb.utils import embedding_functions
@@ -30,7 +35,18 @@ class InitConfig(BaseConfig):
 
         return
 
-
     def _set_embedding_function(self, ef):
         self.ef = ef
         return
+
+    def _setup_logging(self, debug_level):
+        level = logging.WARNING  # Default level
+        if debug_level is not None:
+            level = getattr(logging, debug_level.upper(), None)
+            if not isinstance(level, int):
+                raise ValueError(f'Invalid log level: {debug_level}')
+
+        logging.basicConfig(format="%(asctime)s [%(name)s] [%(levelname)s] %(message)s",
+                            level=level)
+        self.logger = logging.getLogger(__name__)
+        return
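
The level lookup in `_setup_logging` can be exercised on its own. Below is a standalone sketch of the same `getattr` pattern; `resolve_level` is a hypothetical helper, not part of this commit:

```python
import logging

def resolve_level(name=None):
    # Hypothetical helper mirroring _setup_logging: map a level name to
    # logging's numeric constant, default to WARNING, reject unknown names.
    if name is None:
        return logging.WARNING
    level = getattr(logging, name.upper(), None)
    if not isinstance(level, int):
        raise ValueError(f"Invalid log level: {name}")
    return level

print(resolve_level("debug"))  # 10 (logging.DEBUG); lookup is case-insensitive
print(resolve_level())         # 30 (logging.WARNING), the default
```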

+ 9 - 1
embedchain/embedchain.py

@@ -1,5 +1,6 @@
 import openai
 import os
+import logging
 from string import Template
 
 from chromadb.utils import embedding_functions
@@ -181,7 +182,9 @@ class EmbedChain:
             config = QueryConfig()
         context = self.retrieve_from_database(input_query)
         prompt = self.generate_prompt(input_query, context, config.template)
+        logging.info(f"Prompt: {prompt}")
         answer = self.get_answer_from_llm(prompt, config)
+        logging.info(f"Answer: {answer}")
         return answer
 
     def generate_chat_prompt(self, input_query, context, chat_history=''):
@@ -224,13 +227,16 @@ class EmbedChain:
             context,
             chat_history=chat_history,
         )
+        logging.info(f"Prompt: {prompt}")
         answer = self.get_answer_from_llm(prompt, config)
         memory.chat_memory.add_user_message(input_query)
+        
         if isinstance(answer, str):
             memory.chat_memory.add_ai_message(answer)
+            logging.info(f"Answer: {answer}")
             return answer
         else:
-            #this is a streamed response and needs to be handled differently
+            #this is a streamed response and needs to be handled differently.
             return self._stream_chat_response(answer)
 
     def _stream_chat_response(self, answer):
@@ -239,6 +245,7 @@ class EmbedChain:
             streamed_answer += chunk  # accumulate the streamed chunks
             yield chunk
         memory.chat_memory.add_ai_message(streamed_answer)
+        logging.info(f"Answer: {streamed_answer}")
           
 
     def dry_run(self, input_query, config: QueryConfig = None):
@@ -258,6 +265,7 @@ class EmbedChain:
             config = QueryConfig()
         context = self.retrieve_from_database(input_query)
         prompt = self.generate_prompt(input_query, context, config.template)
+        logging.info(f"Prompt: {prompt}")
         return prompt
 
     def count(self):
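
With logging enabled, `dry_run` is a cheap way to see the new `Prompt:` log line without an LLM call, since it renders and logs the prompt but skips `get_answer_from_llm`. A hypothetical session continuing the earlier sketch (the log line shown is illustrative; the module-level `logging.info` calls go through the root logger, so `%(name)s` renders as `root`):

```python
# Continuing the earlier sketch: app = App(InitConfig(log_level="INFO")).
# dry_run() logs the rendered prompt without calling the LLM.
prompt = app.dry_run("What is Embedchain?")
# Expected log output (timestamp illustrative), per _setup_logging's format:
#   2023-07-01 12:00:00,000 [root] [INFO] Prompt: ...rendered template...
print(prompt)
```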