chroma_db.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
import logging
from typing import Any, Dict, List, Optional, Set

from langchain.docstore.document import Document

try:
    import chromadb
except RuntimeError:
    # Importing chromadb can raise RuntimeError; switch to pysqlite3 via
    # embedchain's helper and retry. Every other chromadb import must come
    # AFTER this guard, otherwise it re-triggers the failing package import
    # before the fallback has a chance to run.
    from embedchain.utils import use_pysqlite3

    use_pysqlite3()
    import chromadb

from chromadb.config import Settings
from chromadb.errors import InvalidDimensionException

from embedchain.vectordb.base_vector_db import BaseVectorDB
  13. class ChromaDB(BaseVectorDB):
  14. """Vector database using ChromaDB."""
  15. def __init__(self, db_dir=None, embedding_fn=None, host=None, port=None):
  16. self.embedding_fn = embedding_fn
  17. if not hasattr(embedding_fn, "__call__"):
  18. raise ValueError("Embedding function is not a function")
  19. if host and port:
  20. logging.info(f"Connecting to ChromaDB server: {host}:{port}")
  21. self.settings = Settings(chroma_server_host=host, chroma_server_http_port=port)
  22. self.client = chromadb.HttpClient(self.settings)
  23. else:
  24. if db_dir is None:
  25. db_dir = "db"
  26. self.settings = Settings(anonymized_telemetry=False, allow_reset=True)
  27. self.client = chromadb.PersistentClient(
  28. path=db_dir,
  29. settings=self.settings,
  30. )
  31. super().__init__()
  32. def _get_or_create_db(self):
  33. """Get or create the database."""
  34. return self.client
  35. def _get_or_create_collection(self, name):
  36. """Get or create the collection."""
  37. self.collection = self.client.get_or_create_collection(
  38. name=name,
  39. embedding_function=self.embedding_fn,
  40. )
  41. return self.collection
  42. def get(self, ids: List[str], where: Dict[str, any]) -> List[str]:
  43. """
  44. Get existing doc ids present in vector database
  45. :param ids: list of doc ids to check for existance
  46. :param where: Optional. to filter data
  47. """
  48. existing_docs = self.collection.get(
  49. ids=ids,
  50. where=where, # optional filter
  51. )
  52. return set(existing_docs["ids"])
  53. def add(self, documents: List[str], metadatas: List[object], ids: List[str]) -> Any:
  54. """
  55. add data in vector database
  56. :param documents: list of texts to add
  57. :param metadatas: list of metadata associated with docs
  58. :param ids: ids of docs
  59. """
  60. self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
  61. def _format_result(self, results):
  62. return [
  63. (Document(page_content=result[0], metadata=result[1] or {}), result[2])
  64. for result in zip(
  65. results["documents"][0],
  66. results["metadatas"][0],
  67. results["distances"][0],
  68. )
  69. ]
  70. def query(self, input_query: List[str], n_results: int, where: Dict[str, any]) -> List[str]:
  71. """
  72. query contents from vector data base based on vector similarity
  73. :param input_query: list of query string
  74. :param n_results: no of similar documents to fetch from database
  75. :param where: Optional. to filter data
  76. :return: The content of the document that matched your query.
  77. """
  78. try:
  79. result = self.collection.query(
  80. query_texts=[
  81. input_query,
  82. ],
  83. n_results=n_results,
  84. where=where,
  85. )
  86. except InvalidDimensionException as e:
  87. raise InvalidDimensionException(
  88. e.message()
  89. + ". This is commonly a side-effect when an embedding function, different from the one used to add the embeddings, is used to retrieve an embedding from the database." # noqa E501
  90. ) from None
  91. results_formatted = self._format_result(result)
  92. contents = [result[0].page_content for result in results_formatted]
  93. return contents
  94. def count(self) -> int:
  95. return self.collection.count()
  96. def reset(self):
  97. # Delete all data from the database
  98. self.client.reset()