# chroma_db.py 4.0 KB

import logging
from typing import Any, Dict, List, Optional, Set

from langchain.docstore.document import Document

# `chromadb` must first be imported inside this guard: importing a submodule
# such as `chromadb.errors` beforehand would import the package outside the
# try/except, so the RuntimeError fallback to pysqlite3 could never run.
try:
    import chromadb
except RuntimeError:
    from embedchain.utils import use_pysqlite3

    use_pysqlite3()
    import chromadb

from chromadb.config import Settings
from chromadb.errors import InvalidDimensionException

from embedchain.helper_classes.json_serializable import register_deserializable
from embedchain.vectordb.base_vector_db import BaseVectorDB


  14. @register_deserializable
  15. class ChromaDB(BaseVectorDB):
  16. """Vector database using ChromaDB."""
  17. def __init__(self, db_dir=None, embedding_fn=None, host=None, port=None):
  18. self.embedding_fn = embedding_fn
  19. if not hasattr(embedding_fn, "__call__"):
  20. raise ValueError("Embedding function is not a function")
  21. if host and port:
  22. logging.info(f"Connecting to ChromaDB server: {host}:{port}")
  23. self.client = chromadb.HttpClient(host=host, port=port)
  24. else:
  25. if db_dir is None:
  26. db_dir = "db"
  27. self.settings = Settings(anonymized_telemetry=False, allow_reset=True)
  28. self.client = chromadb.PersistentClient(
  29. path=db_dir,
  30. settings=self.settings,
  31. )
  32. super().__init__()
  33. def _get_or_create_db(self):
  34. """Get or create the database."""
  35. return self.client
  36. def _get_or_create_collection(self, name):
  37. """Get or create the collection."""
  38. self.collection = self.client.get_or_create_collection(
  39. name=name,
  40. embedding_function=self.embedding_fn,
  41. )
  42. return self.collection
  43. def get(self, ids: List[str], where: Dict[str, any]) -> List[str]:
  44. """
  45. Get existing doc ids present in vector database
  46. :param ids: list of doc ids to check for existance
  47. :param where: Optional. to filter data
  48. """
  49. existing_docs = self.collection.get(
  50. ids=ids,
  51. where=where, # optional filter
  52. )
  53. return set(existing_docs["ids"])
  54. def add(self, documents: List[str], metadatas: List[object], ids: List[str]) -> Any:
  55. """
  56. add data in vector database
  57. :param documents: list of texts to add
  58. :param metadatas: list of metadata associated with docs
  59. :param ids: ids of docs
  60. """
  61. self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
  62. def _format_result(self, results):
  63. return [
  64. (Document(page_content=result[0], metadata=result[1] or {}), result[2])
  65. for result in zip(
  66. results["documents"][0],
  67. results["metadatas"][0],
  68. results["distances"][0],
  69. )
  70. ]
  71. def query(self, input_query: List[str], n_results: int, where: Dict[str, any]) -> List[str]:
  72. """
  73. query contents from vector data base based on vector similarity
  74. :param input_query: list of query string
  75. :param n_results: no of similar documents to fetch from database
  76. :param where: Optional. to filter data
  77. :return: The content of the document that matched your query.
  78. """
  79. try:
  80. result = self.collection.query(
  81. query_texts=[
  82. input_query,
  83. ],
  84. n_results=n_results,
  85. where=where,
  86. )
  87. except InvalidDimensionException as e:
  88. raise InvalidDimensionException(
  89. e.message()
  90. + ". This is commonly a side-effect when an embedding function, different from the one used to add the embeddings, is used to retrieve an embedding from the database." # noqa E501
  91. ) from None
  92. results_formatted = self._format_result(result)
  93. contents = [result[0].page_content for result in results_formatted]
  94. return contents
  95. def count(self) -> int:
  96. return self.collection.count()
  97. def reset(self):
  98. # Delete all data from the database
  99. self.client.reset()