chroma_db.py 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. import logging
  2. import os
  3. import chromadb
  4. from chromadb.utils import embedding_functions
  5. from embedchain.vectordb.base_vector_db import BaseVectorDB
  6. class ChromaDB(BaseVectorDB):
  7. """Vector database using ChromaDB."""
  8. def __init__(self, db_dir=None, ef=None, host=None, port=None):
  9. if ef:
  10. self.ef = ef
  11. else:
  12. self.ef = embedding_functions.OpenAIEmbeddingFunction(
  13. api_key=os.getenv("OPENAI_API_KEY"),
  14. organization_id=os.getenv("OPENAI_ORGANIZATION"),
  15. model_name="text-embedding-ada-002",
  16. )
  17. if host and port:
  18. logging.info(f"Connecting to ChromaDB server: {host}:{port}")
  19. self.client_settings = chromadb.config.Settings(
  20. chroma_api_impl="rest",
  21. chroma_server_host=host,
  22. chroma_server_http_port=port,
  23. )
  24. else:
  25. if db_dir is None:
  26. db_dir = "db"
  27. self.client_settings = chromadb.config.Settings(
  28. chroma_db_impl="duckdb+parquet",
  29. persist_directory=db_dir,
  30. anonymized_telemetry=False,
  31. )
  32. super().__init__()
  33. def _get_or_create_db(self):
  34. """Get or create the database."""
  35. return chromadb.Client(self.client_settings)
  36. def _get_or_create_collection(self):
  37. """Get or create the collection."""
  38. return self.client.get_or_create_collection(
  39. "embedchain_store",
  40. embedding_function=self.ef,
  41. )