# chroma_db.py
import logging
from typing import Any, Dict, List, Optional, Set

from langchain.docstore.document import Document

try:
    import chromadb
except RuntimeError:
    # Importing chromadb raised RuntimeError (presumably its sqlite3 version
    # check); switch to pysqlite3 and retry the import.
    from embedchain.utils import use_pysqlite3

    use_pysqlite3()
    import chromadb

# These must come after the fallback above: importing any chromadb submodule
# first imports the chromadb package itself, which would raise before the
# pysqlite3 fallback had a chance to run.
from chromadb.config import Settings
from chromadb.errors import InvalidDimensionException

from embedchain.vectordb.base_vector_db import BaseVectorDB
  13. class ChromaDB(BaseVectorDB):
  14. """Vector database using ChromaDB."""
  15. def __init__(self, db_dir=None, embedding_fn=None, host=None, port=None):
  16. self.embedding_fn = embedding_fn
  17. if not hasattr(embedding_fn, "__call__"):
  18. raise ValueError("Embedding function is not a function")
  19. if host and port:
  20. logging.info(f"Connecting to ChromaDB server: {host}:{port}")
  21. self.client = chromadb.HttpClient(host=host, port=8000)
  22. else:
  23. if db_dir is None:
  24. db_dir = "db"
  25. self.settings = Settings(anonymized_telemetry=False, allow_reset=True)
  26. self.client = chromadb.PersistentClient(
  27. path=db_dir,
  28. settings=self.settings,
  29. )
  30. super().__init__()
  31. def _get_or_create_db(self):
  32. """Get or create the database."""
  33. return self.client
  34. def _get_or_create_collection(self, name):
  35. """Get or create the collection."""
  36. self.collection = self.client.get_or_create_collection(
  37. name=name,
  38. embedding_function=self.embedding_fn,
  39. )
  40. return self.collection
  41. def get(self, ids: List[str], where: Dict[str, any]) -> List[str]:
  42. """
  43. Get existing doc ids present in vector database
  44. :param ids: list of doc ids to check for existance
  45. :param where: Optional. to filter data
  46. """
  47. existing_docs = self.collection.get(
  48. ids=ids,
  49. where=where, # optional filter
  50. )
  51. return set(existing_docs["ids"])
  52. def add(self, documents: List[str], metadatas: List[object], ids: List[str]) -> Any:
  53. """
  54. add data in vector database
  55. :param documents: list of texts to add
  56. :param metadatas: list of metadata associated with docs
  57. :param ids: ids of docs
  58. """
  59. self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
  60. def _format_result(self, results):
  61. return [
  62. (Document(page_content=result[0], metadata=result[1] or {}), result[2])
  63. for result in zip(
  64. results["documents"][0],
  65. results["metadatas"][0],
  66. results["distances"][0],
  67. )
  68. ]
  69. def query(self, input_query: List[str], n_results: int, where: Dict[str, any]) -> List[str]:
  70. """
  71. query contents from vector data base based on vector similarity
  72. :param input_query: list of query string
  73. :param n_results: no of similar documents to fetch from database
  74. :param where: Optional. to filter data
  75. :return: The content of the document that matched your query.
  76. """
  77. try:
  78. result = self.collection.query(
  79. query_texts=[
  80. input_query,
  81. ],
  82. n_results=n_results,
  83. where=where,
  84. )
  85. except InvalidDimensionException as e:
  86. raise InvalidDimensionException(
  87. e.message()
  88. + ". This is commonly a side-effect when an embedding function, different from the one used to add the embeddings, is used to retrieve an embedding from the database." # noqa E501
  89. ) from None
  90. results_formatted = self._format_result(result)
  91. contents = [result[0].page_content for result in results_formatted]
  92. return contents
  93. def count(self) -> int:
  94. return self.collection.count()
  95. def reset(self):
  96. # Delete all data from the database
  97. self.client.reset()