# chroma_db.py
import logging
from typing import Any, Dict, List, Optional, Set

from langchain.docstore.document import Document

try:
    import chromadb
except RuntimeError:
    from embedchain.utils import use_pysqlite3

    use_pysqlite3()
    import chromadb

# Import chromadb submodules only after the guarded import above; importing one
# earlier would trigger the package import outside the try/except.
from chromadb.config import Settings
from chromadb.errors import InvalidDimensionException

from embedchain.helper_classes.json_serializable import register_deserializable
from embedchain.vectordb.base_vector_db import BaseVectorDB


  14. @register_deserializable
  15. class ChromaDB(BaseVectorDB):
  16. """Vector database using ChromaDB."""
  17. def __init__(self, db_dir=None, embedding_fn=None, host=None, port=None, chroma_settings={}):
  18. self.embedding_fn = embedding_fn
  19. if not hasattr(embedding_fn, "__call__"):
  20. raise ValueError("Embedding function is not a function")
  21. self.settings = Settings()
  22. for key, value in chroma_settings.items():
  23. if hasattr(self.settings, key):
  24. setattr(self.settings, key, value)
  25. if host and port:
  26. logging.info(f"Connecting to ChromaDB server: {host}:{port}")
  27. self.settings.chroma_server_host = host
  28. self.settings.chroma_server_http_port = port
  29. self.settings.chroma_api_impl = "chromadb.api.fastapi.FastAPI"
  30. else:
  31. if db_dir is None:
  32. db_dir = "db"
  33. self.settings.persist_directory = db_dir
  34. self.settings.is_persistent = True
  35. self.client = chromadb.Client(self.settings)
  36. super().__init__()
  37. def _get_or_create_db(self):
  38. """Get or create the database."""
  39. return self.client
  40. def _get_or_create_collection(self, name):
  41. """Get or create the collection."""
  42. self.collection = self.client.get_or_create_collection(
  43. name=name,
  44. embedding_function=self.embedding_fn,
  45. )
  46. return self.collection
  47. def get(self, ids: List[str], where: Dict[str, any]) -> List[str]:
  48. """
  49. Get existing doc ids present in vector database
  50. :param ids: list of doc ids to check for existance
  51. :param where: Optional. to filter data
  52. """
  53. existing_docs = self.collection.get(
  54. ids=ids,
  55. where=where, # optional filter
  56. )
  57. return set(existing_docs["ids"])
  58. def add(self, documents: List[str], metadatas: List[object], ids: List[str]) -> Any:
  59. """
  60. add data in vector database
  61. :param documents: list of texts to add
  62. :param metadatas: list of metadata associated with docs
  63. :param ids: ids of docs
  64. """
  65. self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
  66. def _format_result(self, results):
  67. return [
  68. (Document(page_content=result[0], metadata=result[1] or {}), result[2])
  69. for result in zip(
  70. results["documents"][0],
  71. results["metadatas"][0],
  72. results["distances"][0],
  73. )
  74. ]
  75. def query(self, input_query: List[str], n_results: int, where: Dict[str, any]) -> List[str]:
  76. """
  77. query contents from vector data base based on vector similarity
  78. :param input_query: list of query string
  79. :param n_results: no of similar documents to fetch from database
  80. :param where: Optional. to filter data
  81. :return: The content of the document that matched your query.
  82. """
  83. try:
  84. result = self.collection.query(
  85. query_texts=[
  86. input_query,
  87. ],
  88. n_results=n_results,
  89. where=where,
  90. )
  91. except InvalidDimensionException as e:
  92. raise InvalidDimensionException(
  93. e.message()
  94. + ". This is commonly a side-effect when an embedding function, different from the one used to add the embeddings, is used to retrieve an embedding from the database." # noqa E501
  95. ) from None
  96. results_formatted = self._format_result(result)
  97. contents = [result[0].page_content for result in results_formatted]
  98. return contents
  99. def count(self) -> int:
  100. return self.collection.count()
  101. def reset(self):
  102. # Delete all data from the database
  103. self.client.reset()