chroma_db.py 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. import os
  2. import chromadb
  3. from chromadb.utils import embedding_functions
  4. from embedchain.vectordb.base_vector_db import BaseVectorDB
  5. class ChromaDB(BaseVectorDB):
  6. """Vector database using ChromaDB."""
  7. def __init__(self, db_dir=None, ef=None, host=None, port=None):
  8. if ef:
  9. self.ef = ef
  10. else:
  11. self.ef = embedding_functions.OpenAIEmbeddingFunction(
  12. api_key=os.getenv("OPENAI_API_KEY"),
  13. organization_id=os.getenv("OPENAI_ORGANIZATION"),
  14. model_name="text-embedding-ada-002",
  15. )
  16. if host and port:
  17. self.client_settings = chromadb.config.Settings(
  18. chroma_api_impl="rest",
  19. chroma_server_host=host,
  20. chroma_server_http_port=port,
  21. )
  22. else:
  23. if db_dir is None:
  24. db_dir = "db"
  25. self.client_settings = chromadb.config.Settings(
  26. chroma_db_impl="duckdb+parquet",
  27. persist_directory=db_dir,
  28. anonymized_telemetry=False,
  29. )
  30. super().__init__()
  31. def _get_or_create_db(self):
  32. """Get or create the database."""
  33. return chromadb.Client(self.client_settings)
  34. def _get_or_create_collection(self):
  35. """Get or create the collection."""
  36. return self.client.get_or_create_collection(
  37. "embedchain_store",
  38. embedding_function=self.ef,
  39. )