# chroma.py
  1. import logging
  2. from typing import Any, Dict, List, Optional
  3. from chromadb import Collection, QueryResult
  4. from langchain.docstore.document import Document
  5. from embedchain.config import ChromaDbConfig
  6. from embedchain.helper.json_serializable import register_deserializable
  7. from embedchain.vectordb.base import BaseVectorDB
  8. try:
  9. import chromadb
  10. from chromadb.config import Settings
  11. from chromadb.errors import InvalidDimensionException
  12. except RuntimeError:
  13. from embedchain.utils import use_pysqlite3
  14. use_pysqlite3()
  15. import chromadb
  16. from chromadb.config import Settings
  17. from chromadb.errors import InvalidDimensionException
  18. @register_deserializable
  19. class ChromaDB(BaseVectorDB):
  20. """Vector database using ChromaDB."""
  21. def __init__(self, config: Optional[ChromaDbConfig] = None):
  22. """Initialize a new ChromaDB instance
  23. :param config: Configuration options for Chroma, defaults to None
  24. :type config: Optional[ChromaDbConfig], optional
  25. """
  26. if config:
  27. self.config = config
  28. else:
  29. self.config = ChromaDbConfig()
  30. self.settings = Settings()
  31. if self.config.chroma_settings:
  32. for key, value in self.config.chroma_settings.items():
  33. if hasattr(self.settings, key):
  34. setattr(self.settings, key, value)
  35. if self.config.host and self.config.port:
  36. logging.info(f"Connecting to ChromaDB server: {self.config.host}:{self.config.port}")
  37. self.settings.chroma_server_host = self.config.host
  38. self.settings.chroma_server_http_port = self.config.port
  39. self.settings.chroma_api_impl = "chromadb.api.fastapi.FastAPI"
  40. else:
  41. if self.config.dir is None:
  42. self.config.dir = "db"
  43. self.settings.persist_directory = self.config.dir
  44. self.settings.is_persistent = True
  45. self.client = chromadb.Client(self.settings)
  46. super().__init__(config=self.config)
  47. def _initialize(self):
  48. """
  49. This method is needed because `embedder` attribute needs to be set externally before it can be initialized.
  50. """
  51. if not self.embedder:
  52. raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
  53. self._get_or_create_collection(self.config.collection_name)
  54. def _get_or_create_db(self):
  55. """Called during initialization"""
  56. return self.client
  57. def _get_or_create_collection(self, name: str) -> Collection:
  58. """
  59. Get or create a named collection.
  60. :param name: Name of the collection
  61. :type name: str
  62. :raises ValueError: No embedder configured.
  63. :return: Created collection
  64. :rtype: Collection
  65. """
  66. if not hasattr(self, "embedder") or not self.embedder:
  67. raise ValueError("Cannot create a Chroma database collection without an embedder.")
  68. self.collection = self.client.get_or_create_collection(
  69. name=name,
  70. embedding_function=self.embedder.embedding_fn,
  71. )
  72. return self.collection
  73. def get(self, ids=None, where=None, limit=None):
  74. """
  75. Get existing doc ids present in vector database
  76. :param ids: list of doc ids to check for existence
  77. :type ids: List[str]
  78. :param where: Optional. to filter data
  79. :type where: Dict[str, Any]
  80. :return: Existing documents.
  81. :rtype: List[str]
  82. """
  83. args = {}
  84. if ids:
  85. args["ids"] = ids
  86. if where:
  87. args["where"] = where
  88. if limit:
  89. args["limit"] = limit
  90. return self.collection.get(**args)
  91. def get_advanced(self, where):
  92. return self.collection.get(where=where, limit=1)
  93. def add(self, documents: List[str], metadatas: List[object], ids: List[str]) -> Any:
  94. """
  95. Add vectors to chroma database
  96. :param documents: Documents
  97. :type documents: List[str]
  98. :param metadatas: Metadatas
  99. :type metadatas: List[object]
  100. :param ids: ids
  101. :type ids: List[str]
  102. """
  103. self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
  104. def _format_result(self, results: QueryResult) -> list[tuple[Document, float]]:
  105. """
  106. Format Chroma results
  107. :param results: ChromaDB query results to format.
  108. :type results: QueryResult
  109. :return: Formatted results
  110. :rtype: list[tuple[Document, float]]
  111. """
  112. return [
  113. (Document(page_content=result[0], metadata=result[1] or {}), result[2])
  114. for result in zip(
  115. results["documents"][0],
  116. results["metadatas"][0],
  117. results["distances"][0],
  118. )
  119. ]
  120. def query(self, input_query: List[str], n_results: int, where: Dict[str, Any]) -> List[str]:
  121. """
  122. Query contents from vector data base based on vector similarity
  123. :param input_query: list of query string
  124. :type input_query: List[str]
  125. :param n_results: no of similar documents to fetch from database
  126. :type n_results: int
  127. :param where: to filter data
  128. :type where: Dict[str, Any]
  129. :raises InvalidDimensionException: Dimensions do not match.
  130. :return: The content of the document that matched your query.
  131. :rtype: List[str]
  132. """
  133. try:
  134. result = self.collection.query(
  135. query_texts=[
  136. input_query,
  137. ],
  138. n_results=n_results,
  139. where=where,
  140. )
  141. except InvalidDimensionException as e:
  142. raise InvalidDimensionException(
  143. e.message()
  144. + ". This is commonly a side-effect when an embedding function, different from the one used to add the embeddings, is used to retrieve an embedding from the database." # noqa E501
  145. ) from None
  146. results_formatted = self._format_result(result)
  147. contents = [result[0].page_content for result in results_formatted]
  148. return contents
  149. def set_collection_name(self, name: str):
  150. """
  151. Set the name of the collection. A collection is an isolated space for vectors.
  152. :param name: Name of the collection.
  153. :type name: str
  154. """
  155. self.config.collection_name = name
  156. self._get_or_create_collection(self.config.collection_name)
  157. def count(self) -> int:
  158. """
  159. Count number of documents/chunks embedded in the database.
  160. :return: number of documents
  161. :rtype: int
  162. """
  163. return self.collection.count()
  164. def delete(self, where):
  165. return self.collection.delete(where=where)
  166. def reset(self):
  167. """
  168. Resets the database. Deletes all embeddings irreversibly.
  169. """
  170. # Delete all data from the database
  171. try:
  172. self.client.reset()
  173. except ValueError:
  174. raise ValueError(
  175. "For safety reasons, resetting is disabled."
  176. 'Please enable it by including `chromadb_settings={"allow_reset": True}` in your ChromaDbConfig'
  177. ) from None
  178. # Recreate
  179. self._get_or_create_collection(self.config.collection_name)
  180. # Todo: Automatically recreating a collection with the same name cannot be the best way to handle a reset.
  181. # A downside of this implementation is, if you have two instances,
  182. # the other instance will not get the updated `self.collection` attribute.
  183. # A better way would be to create the collection if it is called again after being reset.
  184. # That means, checking if collection exists in the db-consuming methods, and creating it if it doesn't.
  185. # That's an extra steps for all uses, just to satisfy a niche use case in a niche method. For now, this will do.