Pārlūkot izejas kodu

[Refactor] Improve logging package wide (#1315)

Deshraj Yadav 1 gadu atpakaļ
vecāks
revīzija
3616eaadb4
54 mainītis faili ar 263 papildinājumiem un 231 dzēšanām
  1. 0 1
      embedchain/alembic.ini
  2. 15 15
      embedchain/app.py
  3. 11 9
      embedchain/bots/discord.py
  4. 7 5
      embedchain/bots/slack.py
  5. 5 3
      embedchain/bots/whatsapp.py
  6. 4 2
      embedchain/cache.py
  7. 3 1
      embedchain/chunkers/base_chunker.py
  8. 10 8
      embedchain/client.py
  9. 6 4
      embedchain/config/base_app_config.py
  10. 3 1
      embedchain/config/llm/base.py
  11. 13 10
      embedchain/embedchain.py
  12. 3 1
      embedchain/embedder/nvidia.py
  13. 3 1
      embedchain/evaluation/metrics/answer_relevancy.py
  14. 3 1
      embedchain/evaluation/metrics/groundedness.py
  15. 4 2
      embedchain/helpers/json_serializable.py
  16. 3 1
      embedchain/llm/anthropic.py
  17. 2 1
      embedchain/llm/aws_bedrock.py
  18. 3 1
      embedchain/llm/azure_openai.py
  19. 9 7
      embedchain/llm/base.py
  20. 3 1
      embedchain/llm/google.py
  21. 3 1
      embedchain/llm/huggingface.py
  22. 2 1
      embedchain/llm/openai.py
  23. 3 1
      embedchain/llm/vertex_ai.py
  24. 4 2
      embedchain/loaders/beehiiv.py
  25. 5 3
      embedchain/loaders/directory_loader.py
  26. 4 2
      embedchain/loaders/discord.py
  27. 4 2
      embedchain/loaders/discourse.py
  28. 4 2
      embedchain/loaders/docs_site_loader.py
  29. 3 1
      embedchain/loaders/gmail.py
  30. 3 1
      embedchain/loaders/mysql.py
  31. 3 1
      embedchain/loaders/notion.py
  32. 3 1
      embedchain/loaders/postgres.py
  33. 5 3
      embedchain/loaders/sitemap.py
  34. 7 5
      embedchain/loaders/slack.py
  35. 4 2
      embedchain/loaders/substack.py
  36. 3 1
      embedchain/loaders/web_page.py
  37. 6 4
      embedchain/loaders/youtube_channel.py
  38. 5 3
      embedchain/memory/base.py
  39. 4 2
      embedchain/memory/message.py
  40. 0 1
      embedchain/store/assistants.py
  41. 32 30
      embedchain/utils/misc.py
  42. 4 1
      embedchain/vectordb/chroma.py
  43. 3 1
      embedchain/vectordb/elasticsearch.py
  44. 4 2
      embedchain/vectordb/opensearch.py
  45. 3 1
      embedchain/vectordb/pinecone.py
  46. 3 1
      embedchain/vectordb/zilliz.py
  47. 6 3
      examples/api_server/api_server.py
  48. 6 4
      examples/nextjs/nextjs_discord/app.py
  49. 5 3
      examples/nextjs/nextjs_slack/app.py
  50. 11 9
      examples/rest-api/main.py
  51. 1 1
      pyproject.toml
  52. 0 19
      tests/llm/test_anthrophic.py
  53. 0 27
      tests/llm/test_azure_openai.py
  54. 0 15
      tests/loaders/test_discourse.py

+ 0 - 1
embedchain/alembic.ini

@@ -91,7 +91,6 @@ keys = console
 keys = generic
 
 [logger_root]
-level = WARN
 handlers = console
 qualname =
 

+ 15 - 15
embedchain/app.py

@@ -32,6 +32,8 @@ from embedchain.utils.misc import validate_config
 from embedchain.vectordb.base import BaseVectorDB
 from embedchain.vectordb.chroma import ChromaDB
 
+logger = logging.getLogger(__name__)
+
 
 @register_deserializable
 class App(EmbedChain):
@@ -50,10 +52,10 @@ class App(EmbedChain):
         embedding_model: BaseEmbedder = None,
         llm: BaseLlm = None,
         config_data: dict = None,
-        log_level=logging.WARN,
         auto_deploy: bool = False,
         chunker: ChunkerConfig = None,
         cache_config: CacheConfig = None,
+        log_level: int = logging.WARN,
     ):
         """
         Initialize a new `App` instance.
@@ -68,8 +70,6 @@ class App(EmbedChain):
         :type llm: BaseLlm, optional
         :param config_data: Config dictionary, defaults to None
         :type config_data: dict, optional
-        :param log_level: Log level to use, defaults to logging.WARN
-        :type log_level: int, optional
         :param auto_deploy: Whether to deploy the pipeline automatically, defaults to False
         :type auto_deploy: bool, optional
         :raises Exception: If an error occurs while creating the pipeline
@@ -83,13 +83,12 @@ class App(EmbedChain):
         if name and config:
             raise Exception("Cannot provide both name and config. Please provide only one of them.")
 
-        # logging.basicConfig(level=log_level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
-        self.logger = logging.getLogger(__name__)
-
+        logger.debug("4.0")
         # Initialize the metadata db for the app
         setup_engine(database_uri=os.environ.get("EMBEDCHAIN_DB_URI"))
         init_db()
 
+        logger.debug("4.0")
         self.auto_deploy = auto_deploy
         # Store the dict config as an attribute to be able to send it
         self.config_data = config_data if (config_data and validate_config(config_data)) else None
@@ -119,6 +118,7 @@ class App(EmbedChain):
         self.llm = llm or OpenAILlm()
         self._init_db()
 
+        logger.debug("4.1")
         # Session for the metadata db
         self.db_session = get_session()
 
@@ -126,6 +126,7 @@ class App(EmbedChain):
         if self.cache_config is not None:
             self._init_cache()
 
+        logger.debug("4.2")
         # Send anonymous telemetry
         self._telemetry_props = {"class": self.__class__.__name__}
         self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
@@ -238,7 +239,7 @@ class App(EmbedChain):
                 response.raise_for_status()
                 return response.status_code == 200
         except Exception as e:
-            self.logger.exception(f"Error occurred during file upload: {str(e)}")
+            logger.exception(f"Error occurred during file upload: {str(e)}")
             print("❌ Error occurred during file upload!")
             return False
 
@@ -272,7 +273,7 @@ class App(EmbedChain):
                 metadata = {"file_path": data_value, "s3_key": s3_key}
                 data_value = presigned_url
             else:
-                self.logger.error(f"File upload failed for hash: {data_hash}")
+                logger.error(f"File upload failed for hash: {data_hash}")
                 return False
         else:
             if data_type == "qna_pair":
@@ -336,6 +337,7 @@ class App(EmbedChain):
         :return: An instance of the App class.
         :rtype: App
         """
+        logger.debug("6")
         # Backward compatibility for yaml_path
         if yaml_path and not config_path:
             config_path = yaml_path
@@ -357,15 +359,13 @@ class App(EmbedChain):
         elif config and isinstance(config, dict):
             config_data = config
         else:
-            logging.error(
+            logger.error(
                 "Please provide either a config file path (YAML or JSON) or a config dictionary. Falling back to defaults because no config is provided.",  # noqa: E501
             )
             config_data = {}
 
-        try:
-            validate_config(config_data)
-        except Exception as e:
-            raise Exception(f"Error occurred while validating the config. Error: {str(e)}")
+        # Validate the config
+        validate_config(config_data)
 
         app_config_data = config_data.get("app", {}).get("config", {})
         vector_db_config_data = config_data.get("vectordb", {})
@@ -477,12 +477,12 @@ class App(EmbedChain):
             EvalMetric.GROUNDEDNESS.value,
         ]
 
-        logging.info(f"Collecting data from {len(queries)} questions for evaluation...")
+        logger.info(f"Collecting data from {len(queries)} questions for evaluation...")
         dataset = []
         for q, a, c in zip(queries, answers, contexts):
             dataset.append(EvalData(question=q, answer=a, contexts=c))
 
-        logging.info(f"Evaluating {len(dataset)} data points...")
+        logger.info(f"Evaluating {len(dataset)} data points...")
         result = {}
         with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
             future_to_metric = {executor.submit(self._eval, dataset, metric): metric for metric in metrics}

+ 11 - 9
embedchain/bots/discord.py

@@ -17,6 +17,8 @@ except ModuleNotFoundError:
     ) from None
 
 
+logger = logging.getLogger(__name__)
+
 intents = discord.Intents.default()
 intents.message_content = True
 client = discord.Client(intents=intents)
@@ -37,7 +39,7 @@ class DiscordBot(BaseBot):
             self.add(data)
             response = f"Added data from: {data}"
         except Exception:
-            logging.exception(f"Failed to add data {data}.")
+            logger.exception(f"Failed to add data {data}.")
             response = "Some error occurred while adding data."
         return response
 
@@ -45,7 +47,7 @@ class DiscordBot(BaseBot):
         try:
             response = self.query(message)
         except Exception:
-            logging.exception(f"Failed to query {message}.")
+            logger.exception(f"Failed to query {message}.")
             response = "An error occurred. Please try again!"
         return response
 
@@ -60,7 +62,7 @@ class DiscordBot(BaseBot):
 async def query_command(interaction: discord.Interaction, question: str):
     await interaction.response.defer()
     member = client.guilds[0].get_member(client.user.id)
-    logging.info(f"User: {member}, Query: {question}")
+    logger.info(f"User: {member}, Query: {question}")
     try:
         answer = discord_bot.ask_bot(question)
         if args.include_question:
@@ -70,20 +72,20 @@ async def query_command(interaction: discord.Interaction, question: str):
         await interaction.followup.send(response)
     except Exception as e:
         await interaction.followup.send("An error occurred. Please try again!")
-        logging.error("Error occurred during 'query' command:", e)
+        logger.error("Error occurred during 'query' command:", e)
 
 
 @tree.command(name="add", description="add new content to the embedchain database")
 async def add_command(interaction: discord.Interaction, url_or_text: str):
     await interaction.response.defer()
     member = client.guilds[0].get_member(client.user.id)
-    logging.info(f"User: {member}, Add: {url_or_text}")
+    logger.info(f"User: {member}, Add: {url_or_text}")
     try:
         response = discord_bot.add_data(url_or_text)
         await interaction.followup.send(response)
     except Exception as e:
         await interaction.followup.send("An error occurred. Please try again!")
-        logging.error("Error occurred during 'add' command:", e)
+        logger.error("Error occurred during 'add' command:", e)
 
 
 @tree.command(name="ping", description="Simple ping pong command")
@@ -96,7 +98,7 @@ async def on_app_command_error(interaction: discord.Interaction, error: discord.
     if isinstance(error, commands.CommandNotFound):
         await interaction.followup.send("Invalid command. Please refer to the documentation for correct syntax.")
     else:
-        logging.error("Error occurred during command execution:", error)
+        logger.error("Error occurred during command execution:", error)
 
 
 @client.event
@@ -104,8 +106,8 @@ async def on_ready():
     # TODO: Sync in admin command, to not hit rate limits.
     # This might be overkill for most users, and it would require to set a guild or user id, where sync is allowed.
     await tree.sync()
-    logging.debug("Command tree synced")
-    logging.info(f"Logged in as {client.user.name}")
+    logger.debug("Command tree synced")
+    logger.info(f"Logged in as {client.user.name}")
 
 
 def start_command():

+ 7 - 5
embedchain/bots/slack.py

@@ -19,6 +19,8 @@ except ModuleNotFoundError:
     ) from None
 
 
+logger = logging.getLogger(__name__)
+
 SLACK_BOT_TOKEN = os.environ.get("SLACK_BOT_TOKEN")
 
 
@@ -42,10 +44,10 @@ class SlackBot(BaseBot):
                     try:
                         response = self.chat_bot.chat(question)
                         self.send_slack_message(message["channel"], response)
-                        logging.info("Query answered successfully!")
+                        logger.info("Query answered successfully!")
                     except Exception as e:
                         self.send_slack_message(message["channel"], "An error occurred. Please try again!")
-                        logging.error("Error occurred during 'query' command:", e)
+                        logger.error("Error occurred during 'query' command:", e)
                 elif text.startswith("add"):
                     _, data_type, url_or_text = text.split(" ", 2)
                     if url_or_text.startswith("<") and url_or_text.endswith(">"):
@@ -55,10 +57,10 @@ class SlackBot(BaseBot):
                         self.send_slack_message(message["channel"], f"Added {data_type} : {url_or_text}")
                     except ValueError as e:
                         self.send_slack_message(message["channel"], f"Error: {str(e)}")
-                        logging.error("Error occurred during 'add' command:", e)
+                        logger.error("Error occurred during 'add' command:", e)
                     except Exception as e:
                         self.send_slack_message(message["channel"], f"Failed to add {data_type} : {url_or_text}")
-                        logging.error("Error occurred during 'add' command:", e)
+                        logger.error("Error occurred during 'add' command:", e)
 
     def send_slack_message(self, channel, message):
         response = self.client.chat_postMessage(channel=channel, text=message)
@@ -68,7 +70,7 @@ class SlackBot(BaseBot):
         app = Flask(__name__)
 
         def signal_handler(sig, frame):
-            logging.info("\nGracefully shutting down the SlackBot...")
+            logger.info("\nGracefully shutting down the SlackBot...")
             sys.exit(0)
 
         signal.signal(signal.SIGINT, signal_handler)

+ 5 - 3
embedchain/bots/whatsapp.py

@@ -8,6 +8,8 @@ from embedchain.helpers.json_serializable import register_deserializable
 
 from .base import BaseBot
 
+logger = logging.getLogger(__name__)
+
 
 @register_deserializable
 class WhatsAppBot(BaseBot):
@@ -35,7 +37,7 @@ class WhatsAppBot(BaseBot):
             self.add(data)
             response = f"Added data from: {data}"
         except Exception:
-            logging.exception(f"Failed to add data {data}.")
+            logger.exception(f"Failed to add data {data}.")
             response = "Some error occurred while adding data."
         return response
 
@@ -43,7 +45,7 @@ class WhatsAppBot(BaseBot):
         try:
             response = self.query(message)
         except Exception:
-            logging.exception(f"Failed to query {message}.")
+            logger.exception(f"Failed to query {message}.")
             response = "An error occurred. Please try again!"
         return response
 
@@ -51,7 +53,7 @@ class WhatsAppBot(BaseBot):
         app = self.flask.Flask(__name__)
 
         def signal_handler(sig, frame):
-            logging.info("\nGracefully shutting down the WhatsAppBot...")
+            logger.info("\nGracefully shutting down the WhatsAppBot...")
             sys.exit(0)
 
         signal.signal(signal.SIGINT, signal_handler)

+ 4 - 2
embedchain/cache.py

@@ -14,6 +14,8 @@ from gptcache.similarity_evaluation.distance import \
 from gptcache.similarity_evaluation.exact_match import \
     ExactMatchEvaluation  # noqa: F401
 
+logger = logging.getLogger(__name__)
+
 
 def gptcache_pre_function(data: dict[str, Any], **params: dict[str, Any]):
     return data["input_query"]
@@ -24,12 +26,12 @@ def gptcache_data_manager(vector_dimension):
 
 
 def gptcache_data_convert(cache_data):
-    logging.info("[Cache] Cache hit, returning cache data...")
+    logger.info("[Cache] Cache hit, returning cache data...")
     return cache_data
 
 
 def gptcache_update_cache_callback(llm_data, update_cache_func, *args, **kwargs):
-    logging.info("[Cache] Cache missed, updating cache...")
+    logger.info("[Cache] Cache missed, updating cache...")
     update_cache_func(Answer(llm_data, CacheDataType.STR))
     return llm_data
 

+ 3 - 1
embedchain/chunkers/base_chunker.py

@@ -6,6 +6,8 @@ from embedchain.config.add_config import ChunkerConfig
 from embedchain.helpers.json_serializable import JSONSerializable
 from embedchain.models.data_type import DataType
 
+logger = logging.getLogger(__name__)
+
 
 class BaseChunker(JSONSerializable):
     def __init__(self, text_splitter):
@@ -27,7 +29,7 @@ class BaseChunker(JSONSerializable):
         chunk_ids = []
         id_map = {}
         min_chunk_size = config.min_chunk_size if config is not None else 1
-        logging.info(f"Skipping chunks smaller than {min_chunk_size} characters")
+        logger.info(f"Skipping chunks smaller than {min_chunk_size} characters")
         data_result = loader.load_data(src)
         data_records = data_result["data"]
         doc_id = data_result["doc_id"]

+ 10 - 8
embedchain/client.py

@@ -7,6 +7,8 @@ import requests
 
 from embedchain.constants import CONFIG_DIR, CONFIG_FILE
 
+logger = logging.getLogger(__name__)
+
 
 class Client:
     def __init__(self, api_key=None, host="https://apiv2.embedchain.ai"):
@@ -24,7 +26,7 @@ class Client:
         else:
             if "api_key" in self.config_data:
                 self.api_key = self.config_data["api_key"]
-                logging.info("API key loaded successfully!")
+                logger.info("API key loaded successfully!")
             else:
                 raise ValueError(
                     "You are not logged in. Please obtain an API key from https://app.embedchain.ai/settings/keys/"
@@ -64,7 +66,7 @@ class Client:
         with open(CONFIG_FILE, "w") as config_file:
             json.dump(self.config_data, config_file, indent=4)
 
-        logging.info("API key saved successfully!")
+        logger.info("API key saved successfully!")
 
     def clear(self):
         if "api_key" in self.config_data:
@@ -72,17 +74,17 @@ class Client:
             with open(CONFIG_FILE, "w") as config_file:
                 json.dump(self.config_data, config_file, indent=4)
             self.api_key = None
-            logging.info("API key deleted successfully!")
+            logger.info("API key deleted successfully!")
         else:
-            logging.warning("API key not found in the configuration file.")
+            logger.warning("API key not found in the configuration file.")
 
     def update(self, api_key):
         if self.check(api_key):
             self.api_key = api_key
             self.save()
-            logging.info("API key updated successfully!")
+            logger.info("API key updated successfully!")
         else:
-            logging.warning("Invalid API key provided. API key not updated.")
+            logger.warning("Invalid API key provided. API key not updated.")
 
     def check(self, api_key):
         validation_url = f"{self.host}/api/v1/accounts/api_keys/validate/"
@@ -90,8 +92,8 @@ class Client:
         if response.status_code == 200:
             return True
         else:
-            logging.warning(f"Response from API: {response.text}")
-            logging.warning("Invalid API key. Unable to validate.")
+            logger.warning(f"Response from API: {response.text}")
+            logger.warning("Invalid API key. Unable to validate.")
             return False
 
     def get(self):

+ 6 - 4
embedchain/config/base_app_config.py

@@ -5,6 +5,8 @@ from embedchain.config.base_config import BaseConfig
 from embedchain.helpers.json_serializable import JSONSerializable
 from embedchain.vectordb.base import BaseVectorDB
 
+logger = logging.getLogger(__name__)
+
 
 class BaseAppConfig(BaseConfig, JSONSerializable):
     """
@@ -42,15 +44,15 @@ class BaseAppConfig(BaseConfig, JSONSerializable):
 
         if db:
             self._db = db
-            logging.warning(
+            logger.warning(
                 "DEPRECATION WARNING: Please supply the database as the second parameter during app init. "
                 "Such as `app(config=config, db=db)`."
             )
 
         if collection_name:
-            logging.warning("DEPRECATION WARNING: Please supply the collection name to the database config.")
+            logger.warning("DEPRECATION WARNING: Please supply the collection name to the database config.")
         return
 
     def _setup_logging(self, log_level):
-        logging.basicConfig(format="%(asctime)s [%(name)s] [%(levelname)s] %(message)s", level=log_level)
-        self.logger = logging.getLogger(__name__)
+        logger.basicConfig(format="%(asctime)s [%(name)s] [%(levelname)s] %(message)s", level=log_level)
+        self.logger = logger.getLogger(__name__)

+ 3 - 1
embedchain/config/llm/base.py

@@ -6,6 +6,8 @@ from typing import Any, Optional
 from embedchain.config.base_config import BaseConfig
 from embedchain.helpers.json_serializable import register_deserializable
 
+logger = logging.getLogger(__name__)
+
 DEFAULT_PROMPT = """
 You are a Q&A expert system. Your responses must always be rooted in the context provided for each query. Here are some guidelines to follow:
 
@@ -147,7 +149,7 @@ class BaseLlmConfig(BaseConfig):
         :raises ValueError: Stream is not boolean
         """
         if template is not None:
-            logging.warning(
+            logger.warning(
                 "The `template` argument is deprecated and will be removed in a future version. "
                 + "Please use `prompt` instead."
             )

+ 13 - 10
embedchain/embedchain.py

@@ -25,6 +25,8 @@ from embedchain.vectordb.base import BaseVectorDB
 
 load_dotenv()
 
+logger = logging.getLogger(__name__)
+
 
 class EmbedChain(JSONSerializable):
     def __init__(
@@ -143,10 +145,10 @@ class EmbedChain(JSONSerializable):
 
         try:
             DataType(source)
-            logging.warning(
+            logger.warning(
                 f"""Starting from version v0.0.40, Embedchain can automatically detect the data type. So, in the `add` method, the argument order has changed. You no longer need to specify '{source}' for the `source` argument. So the code snippet will be `.add("{data_type}", "{source}")`"""  # noqa #E501
             )
-            logging.warning(
+            logger.warning(
                 "Embedchain is swapping the arguments for you. This functionality might be deprecated in the future, so please adjust your code."  # noqa #E501
             )
             source, data_type = data_type, source
@@ -157,7 +159,7 @@ class EmbedChain(JSONSerializable):
             try:
                 data_type = DataType(data_type)
             except ValueError:
-                logging.info(
+                logger.info(
                     f"Invalid data_type: '{data_type}', using `custom` instead.\n Check docs to pass the valid data type: `https://docs.embedchain.ai/data-sources/overview`"  # noqa: E501
                 )
                 data_type = DataType.CUSTOM
@@ -190,12 +192,12 @@ class EmbedChain(JSONSerializable):
         try:
             self.db_session.commit()
         except Exception as e:
-            logging.error(f"Error adding data source: {e}")
+            logger.error(f"Error adding data source: {e}")
             self.db_session.rollback()
 
         if dry_run:
             data_chunks_info = {"chunks": documents, "metadata": metadatas, "count": len(documents), "type": data_type}
-            logging.debug(f"Dry run info : {data_chunks_info}")
+            logger.debug(f"Dry run info : {data_chunks_info}")
             return data_chunks_info
 
         # Send anonymous telemetry
@@ -490,7 +492,7 @@ class EmbedChain(JSONSerializable):
             contexts_data_for_llm_query = contexts
 
         if self.cache_config is not None:
-            logging.info("Cache enabled. Checking cache...")
+            logger.info("Cache enabled. Checking cache...")
             answer = adapt(
                 llm_handler=self.llm.query,
                 cache_data_convert=gptcache_data_convert,
@@ -562,7 +564,7 @@ class EmbedChain(JSONSerializable):
         self.llm.update_history(app_id=self.config.id, session_id=session_id)
 
         if self.cache_config is not None:
-            logging.info("Cache enabled. Checking cache...")
+            logger.debug("Cache enabled. Checking cache...")
             cache_id = f"{session_id}--{self.config.id}"
             answer = adapt(
                 llm_handler=self.llm.chat,
@@ -575,6 +577,7 @@ class EmbedChain(JSONSerializable):
                 dry_run=dry_run,
             )
         else:
+            logger.debug("Cache disabled. Running chat without cache.")
             answer = self.llm.chat(
                 input_query=input_query, contexts=contexts_data_for_llm_query, config=config, dry_run=dry_run
             )
@@ -652,7 +655,7 @@ class EmbedChain(JSONSerializable):
             self.db_session.query(ChatHistory).filter_by(app_id=self.config.id).delete()
             self.db_session.commit()
         except Exception as e:
-            logging.error(f"Error deleting data sources: {e}")
+            logger.error(f"Error deleting data sources: {e}")
             self.db_session.rollback()
             return None
         self.db.reset()
@@ -694,11 +697,11 @@ class EmbedChain(JSONSerializable):
             self.db_session.query(DataSource).filter_by(hash=source_id, app_id=self.config.id).delete()
             self.db_session.commit()
         except Exception as e:
-            logging.error(f"Error deleting data sources: {e}")
+            logger.error(f"Error deleting data sources: {e}")
             self.db_session.rollback()
             return None
         self.db.delete(where={"hash": source_id})
-        logging.info(f"Successfully deleted {source_id}")
+        logger.info(f"Successfully deleted {source_id}")
         # Send anonymous telemetry
         if self.config.collect_metrics:
             self.telemetry.capture(event_name="delete", properties=self._telemetry_props)

+ 3 - 1
embedchain/embedder/nvidia.py

@@ -8,6 +8,8 @@ from embedchain.config import BaseEmbedderConfig
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.models import VectorDimensions
 
+logger = logging.getLogger(__name__)
+
 
 class NvidiaEmbedder(BaseEmbedder):
     def __init__(self, config: Optional[BaseEmbedderConfig] = None):
@@ -17,7 +19,7 @@ class NvidiaEmbedder(BaseEmbedder):
         super().__init__(config=config)
 
         model = self.config.model or "nvolveqa_40k"
-        logging.info(f"Using NVIDIA embedding model: {model}")
+        logger.info(f"Using NVIDIA embedding model: {model}")
         embedder = NVIDIAEmbeddings(model=model)
         embedding_fn = BaseEmbedder._langchain_default_concept(embedder)
         self.set_embedding_fn(embedding_fn=embedding_fn)

+ 3 - 1
embedchain/evaluation/metrics/answer_relevancy.py

@@ -12,6 +12,8 @@ from embedchain.config.evaluation.base import AnswerRelevanceConfig
 from embedchain.evaluation.base import BaseMetric
 from embedchain.utils.evaluation import EvalData, EvalMetric
 
+logger = logging.getLogger(__name__)
+
 
 class AnswerRelevance(BaseMetric):
     """
@@ -88,6 +90,6 @@ class AnswerRelevance(BaseMetric):
                 try:
                     results.append(future.result())
                 except Exception as e:
-                    logging.error(f"Error evaluating answer relevancy for {data}: {e}")
+                    logger.error(f"Error evaluating answer relevancy for {data}: {e}")
 
         return np.mean(results) if results else 0.0

+ 3 - 1
embedchain/evaluation/metrics/groundedness.py

@@ -12,6 +12,8 @@ from embedchain.config.evaluation.base import GroundednessConfig
 from embedchain.evaluation.base import BaseMetric
 from embedchain.utils.evaluation import EvalData, EvalMetric
 
+logger = logging.getLogger(__name__)
+
 
 class Groundedness(BaseMetric):
     """
@@ -97,6 +99,6 @@ class Groundedness(BaseMetric):
                     score = future.result()
                     results.append(score)
                 except Exception as e:
-                    logging.error(f"Error while evaluating groundedness for data point {data}: {e}")
+                    logger.error(f"Error while evaluating groundedness for data point {data}: {e}")
 
         return np.mean(results) if results else 0.0

+ 4 - 2
embedchain/helpers/json_serializable.py

@@ -8,6 +8,8 @@ T = TypeVar("T", bound="JSONSerializable")
 # NOTE: Through inheritance, all of our classes should be children of JSONSerializable. (highest level)
 # NOTE: The @register_deserializable decorator should be added to all user facing child classes. (lowest level)
 
+logger = logging.getLogger(__name__)
+
 
 def register_deserializable(cls: Type[T]) -> Type[T]:
     """
@@ -57,7 +59,7 @@ class JSONSerializable:
         try:
             return json.dumps(self, default=self._auto_encoder, ensure_ascii=False)
         except Exception as e:
-            logging.error(f"Serialization error: {e}")
+            logger.error(f"Serialization error: {e}")
             return "{}"
 
     @classmethod
@@ -79,7 +81,7 @@ class JSONSerializable:
         try:
             return json.loads(json_str, object_hook=cls._auto_decoder)
         except Exception as e:
-            logging.error(f"Deserialization error: {e}")
+            logger.error(f"Deserialization error: {e}")
             # Return a default instance in case of failure
             return cls()
 

+ 3 - 1
embedchain/llm/anthropic.py

@@ -6,6 +6,8 @@ from embedchain.config import BaseLlmConfig
 from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm
 
+logger = logging.getLogger(__name__)
+
 
 @register_deserializable
 class AnthropicLlm(BaseLlm):
@@ -26,7 +28,7 @@ class AnthropicLlm(BaseLlm):
         )
 
         if config.max_tokens and config.max_tokens != 1000:
-            logging.warning("Config option `max_tokens` is not supported by this model.")
+            logger.warning("Config option `max_tokens` is not supported by this model.")
 
         messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt)
 

+ 2 - 1
embedchain/llm/aws_bedrock.py

@@ -38,7 +38,8 @@ class AWSBedrockLlm(BaseLlm):
         }
 
         if config.stream:
-            from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+            from langchain.callbacks.streaming_stdout import \
+                StreamingStdOutCallbackHandler
 
             callbacks = [StreamingStdOutCallbackHandler()]
             llm = Bedrock(**kwargs, streaming=config.stream, callbacks=callbacks)

+ 3 - 1
embedchain/llm/azure_openai.py

@@ -5,6 +5,8 @@ from embedchain.config import BaseLlmConfig
 from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm
 
+logger = logging.getLogger(__name__)
+
 
 @register_deserializable
 class AzureOpenAILlm(BaseLlm):
@@ -31,7 +33,7 @@ class AzureOpenAILlm(BaseLlm):
         )
 
         if config.top_p and config.top_p != 1:
-            logging.warning("Config option `top_p` is not supported by this model.")
+            logger.warning("Config option `top_p` is not supported by this model.")
 
         messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt)
 

+ 9 - 7
embedchain/llm/base.py

@@ -12,6 +12,8 @@ from embedchain.helpers.json_serializable import JSONSerializable
 from embedchain.memory.base import ChatHistory
 from embedchain.memory.message import ChatMessage
 
+logger = logging.getLogger(__name__)
+
 
 class BaseLlm(JSONSerializable):
     def __init__(self, config: Optional[BaseLlmConfig] = None):
@@ -108,7 +110,7 @@ class BaseLlm(JSONSerializable):
                 )
             else:
                 # If we can't swap in the default, we still proceed but tell users that the history is ignored.
-                logging.warning(
+                logger.warning(
                     "Your bot contains a history, but prompt does not include `$history` key. History is ignored."
                 )
                 prompt = self.config.prompt.substitute(context=context_string, query=input_query)
@@ -159,7 +161,7 @@ class BaseLlm(JSONSerializable):
                 'Searching requires extra dependencies. Install with `pip install --upgrade "embedchain[dataloaders]"`'
             ) from None
         search = DuckDuckGoSearchRun()
-        logging.info(f"Access search to get answers for {input_query}")
+        logger.info(f"Access search to get answers for {input_query}")
         return search.run(input_query)
 
     @staticmethod
@@ -175,7 +177,7 @@ class BaseLlm(JSONSerializable):
         for chunk in answer:
             streamed_answer = streamed_answer + chunk
             yield chunk
-        logging.info(f"Answer: {streamed_answer}")
+        logger.info(f"Answer: {streamed_answer}")
 
     def query(self, input_query: str, contexts: list[str], config: BaseLlmConfig = None, dry_run=False):
         """
@@ -214,13 +216,13 @@ class BaseLlm(JSONSerializable):
             if self.online:
                 k["web_search_result"] = self.access_search_and_get_results(input_query)
             prompt = self.generate_prompt(input_query, contexts, **k)
-            logging.info(f"Prompt: {prompt}")
+            logger.info(f"Prompt: {prompt}")
             if dry_run:
                 return prompt
 
             answer = self.get_answer_from_llm(prompt)
             if isinstance(answer, str):
-                logging.info(f"Answer: {answer}")
+                logger.info(f"Answer: {answer}")
                 return answer
             else:
                 return self._stream_response(answer)
@@ -270,14 +272,14 @@ class BaseLlm(JSONSerializable):
                 k["web_search_result"] = self.access_search_and_get_results(input_query)
 
             prompt = self.generate_prompt(input_query, contexts, **k)
-            logging.info(f"Prompt: {prompt}")
+            logger.info(f"Prompt: {prompt}")
 
             if dry_run:
                 return prompt
 
             answer = self.get_answer_from_llm(prompt)
             if isinstance(answer, str):
-                logging.info(f"Answer: {answer}")
+                logger.info(f"Answer: {answer}")
                 return answer
             else:
                 # this is a streamed response and needs to be handled differently.

+ 3 - 1
embedchain/llm/google.py

@@ -10,6 +10,8 @@ from embedchain.config import BaseLlmConfig
 from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm
 
+logger = logging.getLogger(__name__)
+
 
 @register_deserializable
 class GoogleLlm(BaseLlm):
@@ -36,7 +38,7 @@ class GoogleLlm(BaseLlm):
 
     def _get_answer(self, prompt: str) -> Union[str, Generator[Any, Any, None]]:
         model_name = self.config.model or "gemini-pro"
-        logging.info(f"Using Google LLM model: {model_name}")
+        logger.info(f"Using Google LLM model: {model_name}")
         model = genai.GenerativeModel(model_name=model_name)
 
         generation_config_params = {

+ 3 - 1
embedchain/llm/huggingface.py

@@ -11,6 +11,8 @@ from embedchain.config import BaseLlmConfig
 from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm
 
+logger = logging.getLogger(__name__)
+
 
 @register_deserializable
 class HuggingFaceLlm(BaseLlm):
@@ -58,7 +60,7 @@ class HuggingFaceLlm(BaseLlm):
             raise ValueError("`top_p` must be > 0.0 and < 1.0")
 
         model = config.model
-        logging.info(f"Using HuggingFaceHub with model {model}")
+        logger.info(f"Using HuggingFaceHub with model {model}")
         llm = HuggingFaceHub(
             huggingfacehub_api_token=os.environ["HUGGINGFACE_ACCESS_TOKEN"],
             repo_id=model,

+ 2 - 1
embedchain/llm/openai.py

@@ -65,7 +65,8 @@ class OpenAILlm(BaseLlm):
         messages: list[BaseMessage],
     ) -> str:
         from langchain.output_parsers.openai_tools import JsonOutputToolsParser
-        from langchain_core.utils.function_calling import convert_to_openai_tool
+        from langchain_core.utils.function_calling import \
+            convert_to_openai_tool
 
         openai_tools = [convert_to_openai_tool(tools)]
         chat = chat.bind(tools=openai_tools).pipe(JsonOutputToolsParser())

+ 3 - 1
embedchain/llm/vertex_ai.py

@@ -9,6 +9,8 @@ from embedchain.config import BaseLlmConfig
 from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.llm.base import BaseLlm
 
+logger = logging.getLogger(__name__)
+
 
 @register_deserializable
 class VertexAILlm(BaseLlm):
@@ -28,7 +30,7 @@ class VertexAILlm(BaseLlm):
     @staticmethod
     def _get_answer(prompt: str, config: BaseLlmConfig) -> str:
         if config.top_p and config.top_p != 1:
-            logging.warning("Config option `top_p` is not supported by this model.")
+            logger.warning("Config option `top_p` is not supported by this model.")
 
         messages = BaseLlm._get_messages(prompt, system_prompt=config.system_prompt)
 

+ 4 - 2
embedchain/loaders/beehiiv.py

@@ -9,6 +9,8 @@ from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.utils.misc import is_readable
 
+logger = logging.getLogger(__name__)
+
 
 @register_deserializable
 class BeehiivLoader(BaseLoader):
@@ -90,9 +92,9 @@ class BeehiivLoader(BaseLoader):
                 if is_readable(data):
                     return data
                 else:
-                    logging.warning(f"Page is not readable (too many invalid characters): {link}")
+                    logger.warning(f"Page is not readable (too many invalid characters): {link}")
             except ParserRejectedMarkup as e:
-                logging.error(f"Failed to parse {link}: {e}")
+                logger.error(f"Failed to parse {link}: {e}")
             return None
 
         for link in links:

+ 5 - 3
embedchain/loaders/directory_loader.py

@@ -10,6 +10,8 @@ from embedchain.loaders.base_loader import BaseLoader
 from embedchain.loaders.text_file import TextFileLoader
 from embedchain.utils.misc import detect_datatype
 
+logger = logging.getLogger(__name__)
+
 
 @register_deserializable
 class DirectoryLoader(BaseLoader):
@@ -27,12 +29,12 @@ class DirectoryLoader(BaseLoader):
         if not directory_path.is_dir():
             raise ValueError(f"Invalid path: {path}")
 
-        logging.info(f"Loading data from directory: {path}")
+        logger.info(f"Loading data from directory: {path}")
         data_list = self._process_directory(directory_path)
         doc_id = hashlib.sha256((str(data_list) + str(directory_path)).encode()).hexdigest()
 
         for error in self.errors:
-            logging.warning(error)
+            logger.warning(error)
 
         return {"doc_id": doc_id, "data": data_list}
 
@@ -46,7 +48,7 @@ class DirectoryLoader(BaseLoader):
                 loader = self._predict_loader(file_path)
                 data_list.extend(loader.load_data(str(file_path))["data"])
             elif file_path.is_dir():
-                logging.info(f"Loading data from directory: {file_path}")
+                logger.info(f"Loading data from directory: {file_path}")
         return data_list
 
     def _predict_loader(self, file_path: Path) -> BaseLoader:

+ 4 - 2
embedchain/loaders/discord.py

@@ -5,6 +5,8 @@ import os
 from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
 
+logger = logging.getLogger(__name__)
+
 
 @register_deserializable
 class DiscordLoader(BaseLoader):
@@ -102,7 +104,7 @@ class DiscordLoader(BaseLoader):
 
         class DiscordClient(discord.Client):
             async def on_ready(self) -> None:
-                logging.info("Logged on as {0}!".format(self.user))
+                logger.info("Logged on as {0}!".format(self.user))
                 try:
                     channel = self.get_channel(int(channel_id))
                     if not isinstance(channel, discord.TextChannel):
@@ -121,7 +123,7 @@ class DiscordLoader(BaseLoader):
                                 messages.append(DiscordLoader._format_message(thread_message))
 
                 except Exception as e:
-                    logging.error(e)
+                    logger.error(e)
                     await self.close()
                 finally:
                     await self.close()

+ 4 - 2
embedchain/loaders/discourse.py

@@ -8,6 +8,8 @@ import requests
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.utils.misc import clean_string
 
+logger = logging.getLogger(__name__)
+
 
 class DiscourseLoader(BaseLoader):
     def __init__(self, config: Optional[dict[str, Any]] = None):
@@ -35,7 +37,7 @@ class DiscourseLoader(BaseLoader):
         try:
             response.raise_for_status()
         except Exception as e:
-            logging.error(f"Failed to load post {post_id}: {e}")
+            logger.error(f"Failed to load post {post_id}: {e}")
             return
         response_data = response.json()
         post_contents = clean_string(response_data.get("raw"))
@@ -56,7 +58,7 @@ class DiscourseLoader(BaseLoader):
         self._check_query(query)
         data = []
         data_contents = []
-        logging.info(f"Searching data on discourse url: {self.domain}, for query: {query}")
+        logger.info(f"Searching data on discourse url: {self.domain}, for query: {query}")
         search_url = f"{self.domain}search.json?q={query}"
         response = requests.get(search_url)
         try:

+ 4 - 2
embedchain/loaders/docs_site_loader.py

@@ -15,6 +15,8 @@ except ImportError:
 from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
 
+logger = logging.getLogger(__name__)
+
 
 @register_deserializable
 class DocsSiteLoader(BaseLoader):
@@ -28,7 +30,7 @@ class DocsSiteLoader(BaseLoader):
 
         response = requests.get(url)
         if response.status_code != 200:
-            logging.info(f"Failed to fetch the website: {response.status_code}")
+            logger.info(f"Failed to fetch the website: {response.status_code}")
             return
 
         soup = BeautifulSoup(response.text, "html.parser")
@@ -53,7 +55,7 @@ class DocsSiteLoader(BaseLoader):
     def _load_data_from_url(url: str) -> list:
         response = requests.get(url)
         if response.status_code != 200:
-            logging.info(f"Failed to fetch the website: {response.status_code}")
+            logger.info(f"Failed to fetch the website: {response.status_code}")
             return []
 
         soup = BeautifulSoup(response.content, "html.parser")

+ 3 - 1
embedchain/loaders/gmail.py

@@ -22,6 +22,8 @@ except ImportError:
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.utils.misc import clean_string
 
+logger = logging.getLogger(__name__)
+
 
 class GmailReader:
     SCOPES = ["https://www.googleapis.com/auth/gmail.readonly"]
@@ -114,7 +116,7 @@ class GmailLoader(BaseLoader):
     def load_data(self, query: str):
         reader = GmailReader(query=query)
         emails = reader.load_emails()
-        logging.info(f"Gmail Loader: {len(emails)} emails found for query '{query}'")
+        logger.info(f"Gmail Loader: {len(emails)} emails found for query '{query}'")
 
         data = []
         for email in emails:

+ 3 - 1
embedchain/loaders/mysql.py

@@ -5,6 +5,8 @@ from typing import Any, Optional
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.utils.misc import clean_string
 
+logger = logging.getLogger(__name__)
+
 
 class MySQLLoader(BaseLoader):
     def __init__(self, config: Optional[dict[str, Any]]):
@@ -32,7 +34,7 @@ class MySQLLoader(BaseLoader):
             self.connection = sqlconnector.connection.MySQLConnection(**config)
             self.cursor = self.connection.cursor()
         except (sqlconnector.Error, IOError) as err:
-            logging.info(f"Connection failed: {err}")
+            logger.info(f"Connection failed: {err}")
             raise ValueError(
                 f"Unable to connect with the given config: {config}.",
                 "Please provide the correct configuration to load data from you MySQL DB. \

+ 3 - 1
embedchain/loaders/notion.py

@@ -9,6 +9,8 @@ from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.utils.misc import clean_string
 
+logger = logging.getLogger(__name__)
+
 
 class NotionDocument:
     """
@@ -98,7 +100,7 @@ class NotionLoader(BaseLoader):
 
         id = source[-32:]
         formatted_id = f"{id[:8]}-{id[8:12]}-{id[12:16]}-{id[16:20]}-{id[20:]}"
-        logging.debug(f"Extracted notion page id as: {formatted_id}")
+        logger.debug(f"Extracted notion page id as: {formatted_id}")
 
         integration_token = os.getenv("NOTION_INTEGRATION_TOKEN")
         reader = NotionPageLoader(integration_token=integration_token)

+ 3 - 1
embedchain/loaders/postgres.py

@@ -4,6 +4,8 @@ from typing import Any, Optional
 
 from embedchain.loaders.base_loader import BaseLoader
 
+logger = logging.getLogger(__name__)
+
 
 class PostgresLoader(BaseLoader):
     def __init__(self, config: Optional[dict[str, Any]] = None):
@@ -32,7 +34,7 @@ class PostgresLoader(BaseLoader):
                 conn_params.append(f"{key}={value}")
             config_info = " ".join(conn_params)
 
-        logging.info(f"Connecting to postrgres sql: {config_info}")
+        logger.info(f"Connecting to postrgres sql: {config_info}")
         self.connection = psycopg.connect(conninfo=config_info)
         self.cursor = self.connection.cursor()
 

+ 5 - 3
embedchain/loaders/sitemap.py

@@ -19,6 +19,8 @@ from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.loaders.web_page import WebPageLoader
 
+logger = logging.getLogger(__name__)
+
 
 @register_deserializable
 class SitemapLoader(BaseLoader):
@@ -41,7 +43,7 @@ class SitemapLoader(BaseLoader):
                 response.raise_for_status()
                 soup = BeautifulSoup(response.text, "xml")
             except requests.RequestException as e:
-                logging.error(f"Error fetching sitemap from URL: {e}")
+                logger.error(f"Error fetching sitemap from URL: {e}")
                 return
         elif os.path.isfile(sitemap_source):
             with open(sitemap_source, "r") as file:
@@ -60,7 +62,7 @@ class SitemapLoader(BaseLoader):
                 loader_data = web_page_loader.load_data(link)
                 return loader_data.get("data")
             except ParserRejectedMarkup as e:
-                logging.error(f"Failed to parse {link}: {e}")
+                logger.error(f"Failed to parse {link}: {e}")
             return None
 
         with concurrent.futures.ThreadPoolExecutor() as executor:
@@ -72,6 +74,6 @@ class SitemapLoader(BaseLoader):
                     if data:
                         output.extend(data)
                 except Exception as e:
-                    logging.error(f"Error loading page {link}: {e}")
+                    logger.error(f"Error loading page {link}: {e}")
 
         return {"doc_id": doc_id, "data": output}

+ 7 - 5
embedchain/loaders/slack.py

@@ -11,6 +11,8 @@ from embedchain.utils.misc import clean_string
 
 SLACK_API_BASE_URL = "https://www.slack.com/api/"
 
+logger = logging.getLogger(__name__)
+
 
 class SlackLoader(BaseLoader):
     def __init__(self, config: Optional[dict[str, Any]] = None):
@@ -38,7 +40,7 @@ class SlackLoader(BaseLoader):
                 "SLACK_USER_TOKEN environment variables not provided. Check `https://docs.embedchain.ai/data-sources/slack` to learn more."  # noqa:E501
             )
 
-        logging.info(f"Creating Slack Loader with config: {config}")
+        logger.info(f"Creating Slack Loader with config: {config}")
         # get slack client config params
         slack_bot_token = os.getenv("SLACK_USER_TOKEN")
         ssl_cert = ssl.create_default_context(cafile=certifi.where())
@@ -54,7 +56,7 @@ class SlackLoader(BaseLoader):
             headers=headers,
             team_id=team_id,
         )
-        logging.info("Slack Loader setup successful!")
+        logger.info("Slack Loader setup successful!")
 
     @staticmethod
     def _check_query(query):
@@ -69,7 +71,7 @@ class SlackLoader(BaseLoader):
             data = []
             data_content = []
 
-            logging.info(f"Searching slack conversations for query: {query}")
+            logger.info(f"Searching slack conversations for query: {query}")
             results = self.client.search_messages(
                 query=query,
                 sort="timestamp",
@@ -79,7 +81,7 @@ class SlackLoader(BaseLoader):
 
             messages = results.get("messages")
             num_message = len(messages)
-            logging.info(f"Found {num_message} messages for query: {query}")
+            logger.info(f"Found {num_message} messages for query: {query}")
 
             matches = messages.get("matches", [])
             for message in matches:
@@ -107,7 +109,7 @@ class SlackLoader(BaseLoader):
                 "data": data,
             }
         except Exception as e:
-            logging.warning(f"Error in loading slack data: {e}")
+            logger.warning(f"Error in loading slack data: {e}")
             raise ValueError(
                 f"Error in loading slack data: {e}. Check `https://docs.embedchain.ai/data-sources/slack` to learn more."  # noqa:E501
             ) from e

+ 4 - 2
embedchain/loaders/substack.py

@@ -9,6 +9,8 @@ from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.utils.misc import is_readable
 
+logger = logging.getLogger(__name__)
+
 
 @register_deserializable
 class SubstackLoader(BaseLoader):
@@ -90,9 +92,9 @@ class SubstackLoader(BaseLoader):
                 if is_readable(data):
                     return data
                 else:
-                    logging.warning(f"Page is not readable (too many invalid characters): {link}")
+                    logger.warning(f"Page is not readable (too many invalid characters): {link}")
             except ParserRejectedMarkup as e:
-                logging.error(f"Failed to parse {link}: {e}")
+                logger.error(f"Failed to parse {link}: {e}")
             return None
 
         for link in links:

+ 3 - 1
embedchain/loaders/web_page.py

@@ -14,6 +14,8 @@ from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.utils.misc import clean_string
 
+logger = logging.getLogger(__name__)
+
 
 @register_deserializable
 class WebPageLoader(BaseLoader):
@@ -87,7 +89,7 @@ class WebPageLoader(BaseLoader):
 
         cleaned_size = len(content)
         if original_size != 0:
-            logging.info(
+            logger.info(
                 f"[{url}] Cleaned page size: {cleaned_size} characters, down from {original_size} (shrunk: {original_size-cleaned_size} chars, {round((1-(cleaned_size/original_size)) * 100, 2)}%)"  # noqa:E501
             )
 

+ 6 - 4
embedchain/loaders/youtube_channel.py

@@ -7,6 +7,8 @@ from tqdm import tqdm
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.loaders.youtube_video import YoutubeVideoLoader
 
+logger = logging.getLogger(__name__)
+
 
 class YoutubeChannelLoader(BaseLoader):
     """Loader for youtube channel."""
@@ -36,7 +38,7 @@ class YoutubeChannelLoader(BaseLoader):
                         videos = [entry["url"] for entry in info_dict["entries"]]
                         return videos
             except Exception:
-                logging.error(f"Failed to fetch youtube videos for channel: {channel_name}")
+                logger.error(f"Failed to fetch youtube videos for channel: {channel_name}")
                 return []
 
         def _load_yt_video(video_link):
@@ -45,12 +47,12 @@ class YoutubeChannelLoader(BaseLoader):
                 if each_load_data:
                     return each_load_data.get("data")
             except Exception as e:
-                logging.error(f"Failed to load youtube video {video_link}: {e}")
+                logger.error(f"Failed to load youtube video {video_link}: {e}")
             return None
 
         def _add_youtube_channel():
             video_links = _get_yt_video_links()
-            logging.info("Loading videos from youtube channel...")
+            logger.info("Loading videos from youtube channel...")
             with concurrent.futures.ThreadPoolExecutor() as executor:
                 # Submitting all tasks and storing the future object with the video link
                 future_to_video = {
@@ -67,7 +69,7 @@ class YoutubeChannelLoader(BaseLoader):
                             data.extend(results)
                             data_urls.extend([result.get("meta_data").get("url") for result in results])
                     except Exception as e:
-                        logging.error(f"Failed to process youtube video {video}: {e}")
+                        logger.error(f"Failed to process youtube video {video}: {e}")
 
         _add_youtube_channel()
         doc_id = hashlib.sha256((youtube_url + ", ".join(data_urls)).encode()).hexdigest()

+ 5 - 3
embedchain/memory/base.py

@@ -8,6 +8,8 @@ from embedchain.core.db.models import ChatHistory as ChatHistoryModel
 from embedchain.memory.message import ChatMessage
 from embedchain.memory.utils import merge_metadata_dict
 
+logger = logging.getLogger(__name__)
+
 
 class ChatHistory:
     def __init__(self) -> None:
@@ -31,11 +33,11 @@ class ChatHistory:
         try:
             self.db_session.commit()
         except Exception as e:
-            logging.error(f"Error adding chat memory to db: {e}")
+            logger.error(f"Error adding chat memory to db: {e}")
             self.db_session.rollback()
             return None
 
-        logging.info(f"Added chat memory to db with id: {memory_id}")
+        logger.info(f"Added chat memory to db with id: {memory_id}")
         return memory_id
 
     def delete(self, app_id: str, session_id: Optional[str] = None):
@@ -55,7 +57,7 @@ class ChatHistory:
         try:
             self.db_session.commit()
         except Exception as e:
-            logging.error(f"Error deleting chat history: {e}")
+            logger.error(f"Error deleting chat history: {e}")
             self.db_session.rollback()
 
     def get(

+ 4 - 2
embedchain/memory/message.py

@@ -3,6 +3,8 @@ from typing import Any, Optional
 
 from embedchain.helpers.json_serializable import JSONSerializable
 
+logger = logging.getLogger(__name__)
+
 
 class BaseMessage(JSONSerializable):
     """
@@ -52,7 +54,7 @@ class ChatMessage(JSONSerializable):
 
     def add_user_message(self, message: str, metadata: Optional[dict] = None):
         if self.human_message:
-            logging.info(
+            logger.info(
                 "Human message already exists in the chat message,\
                 overwriting it with new message."
             )
@@ -61,7 +63,7 @@ class ChatMessage(JSONSerializable):
 
     def add_ai_message(self, message: str, metadata: Optional[dict] = None):
         if self.ai_message:
-            logging.info(
+            logger.info(
                 "AI message already exists in the chat message,\
                 overwriting it with new message."
             )

+ 0 - 1
embedchain/store/assistants.py

@@ -157,7 +157,6 @@ class AIAssistant:
         log_level=logging.INFO,
         collect_metrics=True,
     ):
-
         self.name = name or "AI Assistant"
         self.data_sources = data_sources or []
         self.log_level = log_level

+ 32 - 30
embedchain/utils/misc.py

@@ -11,6 +11,8 @@ from tqdm import tqdm
 
 from embedchain.models.data_type import DataType
 
+logger = logging.getLogger(__name__)
+
 
 def parse_content(content, type):
     implemented = ["html.parser", "lxml", "lxml-xml", "xml", "html5lib"]
@@ -61,7 +63,7 @@ def parse_content(content, type):
 
     cleaned_size = len(content)
     if original_size != 0:
-        logging.info(
+        logger.info(
             f"Cleaned page size: {cleaned_size} characters, down from {original_size} (shrunk: {original_size-cleaned_size} chars, {round((1-(cleaned_size/original_size)) * 100, 2)}%)"  # noqa:E501
         )
 
@@ -208,31 +210,31 @@ def detect_datatype(source: Any) -> DataType:
         }
 
         if url.netloc in YOUTUBE_ALLOWED_NETLOCKS:
-            logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")
+            logger.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")
             return DataType.YOUTUBE_VIDEO
 
         if url.netloc in {"notion.so", "notion.site"}:
-            logging.debug(f"Source of `{formatted_source}` detected as `notion`.")
+            logger.debug(f"Source of `{formatted_source}` detected as `notion`.")
             return DataType.NOTION
 
         if url.path.endswith(".pdf"):
-            logging.debug(f"Source of `{formatted_source}` detected as `pdf_file`.")
+            logger.debug(f"Source of `{formatted_source}` detected as `pdf_file`.")
             return DataType.PDF_FILE
 
         if url.path.endswith(".xml"):
-            logging.debug(f"Source of `{formatted_source}` detected as `sitemap`.")
+            logger.debug(f"Source of `{formatted_source}` detected as `sitemap`.")
             return DataType.SITEMAP
 
         if url.path.endswith(".csv"):
-            logging.debug(f"Source of `{formatted_source}` detected as `csv`.")
+            logger.debug(f"Source of `{formatted_source}` detected as `csv`.")
             return DataType.CSV
 
         if url.path.endswith(".mdx") or url.path.endswith(".md"):
-            logging.debug(f"Source of `{formatted_source}` detected as `mdx`.")
+            logger.debug(f"Source of `{formatted_source}` detected as `mdx`.")
             return DataType.MDX
 
         if url.path.endswith(".docx"):
-            logging.debug(f"Source of `{formatted_source}` detected as `docx`.")
+            logger.debug(f"Source of `{formatted_source}` detected as `docx`.")
             return DataType.DOCX
 
         if url.path.endswith(".yaml"):
@@ -242,14 +244,14 @@ def detect_datatype(source: Any) -> DataType:
                 try:
                     yaml_content = yaml.safe_load(response.text)
                 except yaml.YAMLError as exc:
-                    logging.error(f"Error parsing YAML: {exc}")
+                    logger.error(f"Error parsing YAML: {exc}")
                     raise TypeError(f"Not a valid data type. Error loading YAML: {exc}")
 
                 if is_openapi_yaml(yaml_content):
-                    logging.debug(f"Source of `{formatted_source}` detected as `openapi`.")
+                    logger.debug(f"Source of `{formatted_source}` detected as `openapi`.")
                     return DataType.OPENAPI
                 else:
-                    logging.error(
+                    logger.error(
                         f"Source of `{formatted_source}` does not contain all the required \
                         fields of OpenAPI yaml. Check 'https://spec.openapis.org/oas/v3.1.0'"
                     )
@@ -258,35 +260,35 @@ def detect_datatype(source: Any) -> DataType:
                         make sure you have all the required fields in YAML config data"
                     )
             except requests.exceptions.RequestException as e:
-                logging.error(f"Error fetching URL {formatted_source}: {e}")
+                logger.error(f"Error fetching URL {formatted_source}: {e}")
 
         if url.path.endswith(".json"):
-            logging.debug(f"Source of `{formatted_source}` detected as `json_file`.")
+            logger.debug(f"Source of `{formatted_source}` detected as `json_file`.")
             return DataType.JSON
 
         if "docs" in url.netloc or ("docs" in url.path and url.scheme != "file"):
             # `docs_site` detection via path is not accepted for local filesystem URIs,
             # because that would mean all paths that contain `docs` are now doc sites, which is too aggressive.
-            logging.debug(f"Source of `{formatted_source}` detected as `docs_site`.")
+            logger.debug(f"Source of `{formatted_source}` detected as `docs_site`.")
             return DataType.DOCS_SITE
 
         if "github.com" in url.netloc:
-            logging.debug(f"Source of `{formatted_source}` detected as `github`.")
+            logger.debug(f"Source of `{formatted_source}` detected as `github`.")
             return DataType.GITHUB
 
         if is_google_drive_folder(url.netloc + url.path):
-            logging.debug(f"Source of `{formatted_source}` detected as `google drive folder`.")
+            logger.debug(f"Source of `{formatted_source}` detected as `google drive folder`.")
             return DataType.GOOGLE_DRIVE_FOLDER
 
         # If none of the above conditions are met, it's a general web page
-        logging.debug(f"Source of `{formatted_source}` detected as `web_page`.")
+        logger.debug(f"Source of `{formatted_source}` detected as `web_page`.")
         return DataType.WEB_PAGE
 
     elif not isinstance(source, str):
         # For datatypes where source is not a string.
 
         if isinstance(source, tuple) and len(source) == 2 and isinstance(source[0], str) and isinstance(source[1], str):
-            logging.debug(f"Source of `{formatted_source}` detected as `qna_pair`.")
+            logger.debug(f"Source of `{formatted_source}` detected as `qna_pair`.")
             return DataType.QNA_PAIR
 
         # Raise an error if it isn't a string and also not a valid non-string type (one of the previous).
@@ -300,37 +302,37 @@ def detect_datatype(source: Any) -> DataType:
         # Note: checking for string is not necessary anymore.
 
         if source.endswith(".docx"):
-            logging.debug(f"Source of `{formatted_source}` detected as `docx`.")
+            logger.debug(f"Source of `{formatted_source}` detected as `docx`.")
             return DataType.DOCX
 
         if source.endswith(".csv"):
-            logging.debug(f"Source of `{formatted_source}` detected as `csv`.")
+            logger.debug(f"Source of `{formatted_source}` detected as `csv`.")
             return DataType.CSV
 
         if source.endswith(".xml"):
-            logging.debug(f"Source of `{formatted_source}` detected as `xml`.")
+            logger.debug(f"Source of `{formatted_source}` detected as `xml`.")
             return DataType.XML
 
         if source.endswith(".mdx") or source.endswith(".md"):
-            logging.debug(f"Source of `{formatted_source}` detected as `mdx`.")
+            logger.debug(f"Source of `{formatted_source}` detected as `mdx`.")
             return DataType.MDX
 
         if source.endswith(".txt"):
-            logging.debug(f"Source of `{formatted_source}` detected as `text`.")
+            logger.debug(f"Source of `{formatted_source}` detected as `text`.")
             return DataType.TEXT_FILE
 
         if source.endswith(".pdf"):
-            logging.debug(f"Source of `{formatted_source}` detected as `pdf_file`.")
+            logger.debug(f"Source of `{formatted_source}` detected as `pdf_file`.")
             return DataType.PDF_FILE
 
         if source.endswith(".yaml"):
             with open(source, "r") as file:
                 yaml_content = yaml.safe_load(file)
                 if is_openapi_yaml(yaml_content):
-                    logging.debug(f"Source of `{formatted_source}` detected as `openapi`.")
+                    logger.debug(f"Source of `{formatted_source}` detected as `openapi`.")
                     return DataType.OPENAPI
                 else:
-                    logging.error(
+                    logger.error(
                         f"Source of `{formatted_source}` does not contain all the required \
                                   fields of OpenAPI yaml. Check 'https://spec.openapis.org/oas/v3.1.0'"
                     )
@@ -340,11 +342,11 @@ def detect_datatype(source: Any) -> DataType:
                     )
 
         if source.endswith(".json"):
-            logging.debug(f"Source of `{formatted_source}` detected as `json`.")
+            logger.debug(f"Source of `{formatted_source}` detected as `json`.")
             return DataType.JSON
 
         if os.path.exists(source) and is_readable(open(source).read()):
-            logging.debug(f"Source of `{formatted_source}` detected as `text_file`.")
+            logger.debug(f"Source of `{formatted_source}` detected as `text_file`.")
             return DataType.TEXT_FILE
 
         # If the source is a valid file, that's not detectable as a type, an error is raised.
@@ -360,11 +362,11 @@ def detect_datatype(source: Any) -> DataType:
 
         # check if the source is valid json string
         if is_valid_json_string(source):
-            logging.debug(f"Source of `{formatted_source}` detected as `json`.")
+            logger.debug(f"Source of `{formatted_source}` detected as `json`.")
             return DataType.JSON
 
         # Use text as final fallback.
-        logging.debug(f"Source of `{formatted_source}` detected as `text`.")
+        logger.debug(f"Source of `{formatted_source}` detected as `text`.")
         return DataType.TEXT
 
 

+ 4 - 1
embedchain/vectordb/chroma.py

@@ -22,6 +22,9 @@ except RuntimeError:
     from chromadb.errors import InvalidDimensionException
 
 
+logger = logging.getLogger(__name__)
+
+
 @register_deserializable
 class ChromaDB(BaseVectorDB):
     """Vector database using ChromaDB."""
@@ -47,7 +50,7 @@ class ChromaDB(BaseVectorDB):
                     setattr(self.settings, key, value)
 
         if self.config.host and self.config.port:
-            logging.info(f"Connecting to ChromaDB server: {self.config.host}:{self.config.port}")
+            logger.info(f"Connecting to ChromaDB server: {self.config.host}:{self.config.port}")
             self.settings.chroma_server_host = self.config.host
             self.settings.chroma_server_http_port = self.config.port
             self.settings.chroma_api_impl = "chromadb.api.fastapi.FastAPI"

+ 3 - 1
embedchain/vectordb/elasticsearch.py

@@ -14,6 +14,8 @@ from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.utils.misc import chunks
 from embedchain.vectordb.base import BaseVectorDB
 
+logger = logging.getLogger(__name__)
+
 
 @register_deserializable
 class ElasticsearchDB(BaseVectorDB):
@@ -62,7 +64,7 @@ class ElasticsearchDB(BaseVectorDB):
         """
         This method is needed because `embedder` attribute needs to be set externally before it can be initialized.
         """
-        logging.info(self.client.info())
+        logger.info(self.client.info())
         index_settings = {
             "mappings": {
                 "properties": {

+ 4 - 2
embedchain/vectordb/opensearch.py

@@ -19,6 +19,8 @@ from embedchain.config import OpenSearchDBConfig
 from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.vectordb.base import BaseVectorDB
 
+logger = logging.getLogger(__name__)
+
 
 @register_deserializable
 class OpenSearchDB(BaseVectorDB):
@@ -43,12 +45,12 @@ class OpenSearchDB(BaseVectorDB):
             **self.config.extra_params,
         )
         info = self.client.info()
-        logging.info(f"Connected to {info['version']['distribution']}. Version: {info['version']['number']}")
+        logger.info(f"Connected to {info['version']['distribution']}. Version: {info['version']['number']}")
         # Remove auth credentials from config after successful connection
         super().__init__(config=self.config)
 
     def _initialize(self):
-        logging.info(self.client.info())
+        logger.info(self.client.info())
         index_name = self._get_index()
         if self.client.indices.exists(index=index_name):
             print(f"Index '{index_name}' already exists.")

+ 3 - 1
embedchain/vectordb/pinecone.py

@@ -16,6 +16,8 @@ from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.utils.misc import chunks
 from embedchain.vectordb.base import BaseVectorDB
 
+logger = logging.getLogger(__name__)
+
 
 @register_deserializable
 class PineconeDB(BaseVectorDB):
@@ -49,7 +51,7 @@ class PineconeDB(BaseVectorDB):
         # Setup BM25Encoder if sparse vectors are to be used
         self.bm25_encoder = None
         if self.config.hybrid_search:
-            logging.info("Initializing BM25Encoder for sparse vectors..")
+            logger.info("Initializing BM25Encoder for sparse vectors..")
             self.bm25_encoder = self.config.bm25_encoder if self.config.bm25_encoder else BM25Encoder.default()
 
         # Call parent init here because embedder is needed

+ 3 - 1
embedchain/vectordb/zilliz.py

@@ -13,6 +13,8 @@ except ImportError:
         "Zilliz requires extra dependencies. Install with `pip install --upgrade embedchain[milvus]`"
     ) from None
 
+logger = logging.getLogger(__name__)
+
 
 @register_deserializable
 class ZillizVectorDB(BaseVectorDB):
@@ -62,7 +64,7 @@ class ZillizVectorDB(BaseVectorDB):
         :type name: str
         """
         if utility.has_collection(name):
-            logging.info(f"[ZillizDB]: found an existing collection {name}, make sure the auto-id is disabled.")
+            logger.info(f"[ZillizDB]: found an existing collection {name}, make sure the auto-id is disabled.")
             self.collection = Collection(name)
         else:
             fields = [

+ 6 - 3
examples/api_server/api_server.py

@@ -7,6 +7,9 @@ from embedchain import App
 app = Flask(__name__)
 
 
+logger = logging.getLogger(__name__)
+
+
 @app.route("/add", methods=["POST"])
 def add():
     data = request.get_json()
@@ -17,7 +20,7 @@ def add():
             App().add(url_or_text, data_type=data_type)
             return jsonify({"data": f"Added {data_type}: {url_or_text}"}), 200
         except Exception:
-            logging.exception(f"Failed to add {data_type=}: {url_or_text=}")
+            logger.exception(f"Failed to add {data_type=}: {url_or_text=}")
             return jsonify({"error": f"Failed to add {data_type}: {url_or_text}"}), 500
     return jsonify({"error": "Invalid request. Please provide 'data_type' and 'url_or_text' in JSON format."}), 400
 
@@ -31,7 +34,7 @@ def query():
             response = App().query(question)
             return jsonify({"data": response}), 200
         except Exception:
-            logging.exception(f"Failed to query {question=}")
+            logger.exception(f"Failed to query {question=}")
             return jsonify({"error": "An error occurred. Please try again!"}), 500
     return jsonify({"error": "Invalid request. Please provide 'question' in JSON format."}), 400
 
@@ -45,7 +48,7 @@ def chat():
             response = App().chat(question)
             return jsonify({"data": response}), 200
         except Exception:
-            logging.exception(f"Failed to chat {question=}")
+            logger.exception(f"Failed to chat {question=}")
             return jsonify({"error": "An error occurred. Please try again!"}), 500
     return jsonify({"error": "Invalid request. Please provide 'question' in JSON format."}), 400
 

+ 6 - 4
examples/nextjs/nextjs_discord/app.py

@@ -12,10 +12,12 @@ intents.message_content = True
 client = discord.Client(intents=intents)
 discord_bot_name = os.environ["DISCORD_BOT_NAME"]
 
+logger = logging.getLogger(__name__)
+
 
 class NextJSBot:
     def __init__(self) -> None:
-        logging.info("NextJS Bot powered with embedchain.")
+        logger.info("NextJS Bot powered with embedchain.")
 
     def add(self, _):
         raise ValueError("Add is not implemented yet")
@@ -31,11 +33,11 @@ class NextJSBot:
             try:
                 response = response.json()
             except Exception:
-                logging.error(f"Failed to parse response: {response}")
+                logger.error(f"Failed to parse response: {response}")
                 response = {}
             return response
         except Exception:
-            logging.exception(f"Failed to query {message}.")
+            logger.exception(f"Failed to query {message}.")
             response = "An error occurred. Please try again!"
         return response
 
@@ -49,7 +51,7 @@ NEXTJS_BOT = NextJSBot()
 
 @client.event
 async def on_ready():
-    logging.info(f"User {client.user.name} logged in with id: {client.user.id}!")
+    logger.info(f"User {client.user.name} logged in with id: {client.user.id}!")
 
 
 def _get_question(message):

+ 5 - 3
examples/nextjs/nextjs_slack/app.py

@@ -9,6 +9,8 @@ from slack_bolt.adapter.socket_mode import SocketModeHandler
 
 load_dotenv(".env")
 
+logger = logging.getLogger(__name__)
+
 
 def remove_mentions(message):
     mention_pattern = re.compile(r"<@[^>]+>")
@@ -19,7 +21,7 @@ def remove_mentions(message):
 
 class SlackBotApp:
     def __init__(self) -> None:
-        logging.info("Slack Bot using Embedchain!")
+        logger.info("Slack Bot using Embedchain!")
 
     def add(self, _):
         raise ValueError("Add is not implemented yet")
@@ -35,11 +37,11 @@ class SlackBotApp:
             try:
                 response = response.json()
             except Exception:
-                logging.error(f"Failed to parse response: {response}")
+                logger.error(f"Failed to parse response: {response}")
                 response = {}
             return response
         except Exception:
-            logging.exception(f"Failed to query {query}.")
+            logger.exception(f"Failed to query {query}.")
             response = "An error occurred. Please try again!"
         return response
 

+ 11 - 9
examples/rest-api/main.py

@@ -13,6 +13,8 @@ from utils import generate_error_message_for_api_keys
 from embedchain import App
 from embedchain.client import Client
 
+logger = logging.getLogger(__name__)
+
 Base.metadata.create_all(bind=engine)
 
 
@@ -84,7 +86,7 @@ async def create_app_using_default_config(app_id: str, config: UploadFile = None
 
         return DefaultResponse(response=f"App created successfully. App ID: {app_id}")
     except Exception as e:
-        logging.warning(str(e))
+        logger.warning(str(e))
         raise HTTPException(detail=f"Error creating app: {str(e)}", status_code=400)
 
 
@@ -114,13 +116,13 @@ async def get_datasources_associated_with_app_id(app_id: str, db: Session = Depe
         response = app.get_data_sources()
         return {"results": response}
     except ValueError as ve:
-        logging.warning(str(ve))
+        logger.warning(str(ve))
         raise HTTPException(
             detail=generate_error_message_for_api_keys(ve),
             status_code=400,
         )
     except Exception as e:
-        logging.warning(str(e))
+        logger.warning(str(e))
         raise HTTPException(detail=f"Error occurred: {str(e)}", status_code=400)
 
 
@@ -153,13 +155,13 @@ async def add_datasource_to_an_app(body: SourceApp, app_id: str, db: Session = D
         response = app.add(source=body.source, data_type=body.data_type)
         return DefaultResponse(response=response)
     except ValueError as ve:
-        logging.warning(str(ve))
+        logger.warning(str(ve))
         raise HTTPException(
             detail=generate_error_message_for_api_keys(ve),
             status_code=400,
         )
     except Exception as e:
-        logging.warning(str(e))
+        logger.warning(str(e))
         raise HTTPException(detail=f"Error occurred: {str(e)}", status_code=400)
 
 
@@ -191,13 +193,13 @@ async def query_an_app(body: QueryApp, app_id: str, db: Session = Depends(get_db
         response = app.query(body.query)
         return DefaultResponse(response=response)
     except ValueError as ve:
-        logging.warning(str(ve))
+        logger.warning(str(ve))
         raise HTTPException(
             detail=generate_error_message_for_api_keys(ve),
             status_code=400,
         )
     except Exception as e:
-        logging.warning(str(e))
+        logger.warning(str(e))
         raise HTTPException(detail=f"Error occurred: {str(e)}", status_code=400)
 
 
@@ -274,13 +276,13 @@ async def deploy_app(body: DeployAppRequest, app_id: str, db: Session = Depends(
         app.deploy()
         return DefaultResponse(response="App deployed successfully.")
     except ValueError as ve:
-        logging.warning(str(ve))
+        logger.warning(str(ve))
         raise HTTPException(
             detail=generate_error_message_for_api_keys(ve),
             status_code=400,
         )
     except Exception as e:
-        logging.warning(str(e))
+        logger.warning(str(e))
         raise HTTPException(detail=f"Error occurred: {str(e)}", status_code=400)
 
 

+ 1 - 1
pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "embedchain"
-version = "0.1.94"
+version = "0.1.95"
 description = "Simplest open source retrieval(RAG) framework"
 authors = [
     "Taranjeet Singh <taranjeet@embedchain.ai>",

+ 0 - 19
tests/llm/test_anthrophic.py

@@ -50,22 +50,3 @@ def test_get_messages(anthropic_llm):
         SystemMessage(content="Test System Prompt", additional_kwargs={}),
         HumanMessage(content="Test Prompt", additional_kwargs={}, example=False),
     ]
-
-
-def test_get_answer_max_tokens_is_provided(anthropic_llm, caplog):
-    with patch("langchain_community.chat_models.ChatAnthropic") as mock_chat:
-        mock_chat_instance = mock_chat.return_value
-        mock_chat_instance.return_value = MagicMock(content="Test Response")
-
-        prompt = "Test Prompt"
-        config = anthropic_llm.config
-        config.max_tokens = 500
-
-        response = anthropic_llm._get_answer(prompt, config)
-
-        assert response == "Test Response"
-        mock_chat.assert_called_once_with(
-            anthropic_api_key="test_api_key", temperature=config.temperature, model=config.model
-        )
-
-        assert "Config option `max_tokens` is not supported by this model." in caplog.text

+ 0 - 27
tests/llm/test_azure_openai.py

@@ -59,33 +59,6 @@ def test_get_messages(azure_openai_llm):
     ]
 
 
-def test_get_answer_top_p_is_provided(azure_openai_llm, caplog):
-    with patch("langchain_community.chat_models.AzureChatOpenAI") as mock_chat:
-        mock_chat_instance = mock_chat.return_value
-        mock_chat_instance.return_value = MagicMock(content="Test Response")
-
-        prompt = "Test Prompt"
-        config = azure_openai_llm.config
-        config.top_p = 0.5
-
-        response = azure_openai_llm._get_answer(prompt, config)
-
-        assert response == "Test Response"
-        mock_chat.assert_called_once_with(
-            deployment_name=config.deployment_name,
-            openai_api_version="2023-05-15",
-            model_name=config.model or "gpt-3.5-turbo",
-            temperature=config.temperature,
-            max_tokens=config.max_tokens,
-            streaming=config.stream,
-        )
-        mock_chat_instance.assert_called_once_with(
-            azure_openai_llm._get_messages(prompt, system_prompt=config.system_prompt)
-        )
-
-        assert "Config option `top_p` is not supported by this model." in caplog.text
-
-
 def test_when_no_deployment_name_provided():
     config = BaseLlmConfig(temperature=0.7, model="gpt-3.5-turbo", max_tokens=50, system_prompt="System Prompt")
     with pytest.raises(ValueError):

+ 0 - 15
tests/loaders/test_discourse.py

@@ -66,21 +66,6 @@ def test_discourse_loader_load_post_with_valid_post_id(discourse_loader, monkeyp
     assert "meta_data" in post_data
 
 
-def test_discourse_loader_load_post_with_invalid_post_id(discourse_loader, monkeypatch, caplog):
-    def mock_get(*args, **kwargs):
-        class MockResponse:
-            def raise_for_status(self):
-                raise requests.exceptions.RequestException("Test error")
-
-        return MockResponse()
-
-    monkeypatch.setattr(requests, "get", mock_get)
-
-    discourse_loader._load_post(123)
-
-    assert "Failed to load post" in caplog.text
-
-
 def test_discourse_loader_load_data_with_valid_query(discourse_loader, monkeypatch):
     def mock_get(*args, **kwargs):
         class MockResponse: