InitConfig.py 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. import logging
  2. import os
  3. from embedchain.config.BaseConfig import BaseConfig
  4. class InitConfig(BaseConfig):
  5. """
  6. Config to initialize an embedchain `App` instance.
  7. """
  8. def __init__(self, log_level=None, ef=None, db=None, host=None, port=None):
  9. """
  10. :param log_level: Optional. (String) Debug level
  11. ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
  12. :param ef: Optional. Embedding function to use.
  13. :param db: Optional. (Vector) database to use for embeddings.
  14. """
  15. self._setup_logging(log_level)
  16. # Embedding Function
  17. if ef is None:
  18. from chromadb.utils import embedding_functions
  19. self.ef = embedding_functions.OpenAIEmbeddingFunction(
  20. api_key=os.getenv("OPENAI_API_KEY"),
  21. organization_id=os.getenv("OPENAI_ORGANIZATION"),
  22. model_name="text-embedding-ada-002",
  23. )
  24. else:
  25. self.ef = ef
  26. if db is None:
  27. from embedchain.vectordb.chroma_db import ChromaDB
  28. self.db = ChromaDB(ef=self.ef, host=host, port=port)
  29. else:
  30. self.db = db
  31. return
  32. def _set_embedding_function(self, ef):
  33. self.ef = ef
  34. return
  35. def _setup_logging(self, debug_level):
  36. level = logging.WARNING # Default level
  37. if debug_level is not None:
  38. level = getattr(logging, debug_level.upper(), None)
  39. if not isinstance(level, int):
  40. raise ValueError(f"Invalid log level: {debug_level}")
  41. logging.basicConfig(
  42. format="%(asctime)s [%(name)s] [%(levelname)s] %(message)s", level=level
  43. )
  44. self.logger = logging.getLogger(__name__)
  45. return