# chroma.py 11 KB

  1. import logging
  2. from typing import Any, Optional, Union
  3. from chromadb import Collection, QueryResult
  4. from langchain.docstore.document import Document
  5. from tqdm import tqdm
  6. from embedchain.config import ChromaDbConfig
  7. from embedchain.helpers.json_serializable import register_deserializable
  8. from embedchain.vectordb.base import BaseVectorDB
  9. try:
  10. import chromadb
  11. from chromadb.config import Settings
  12. from chromadb.errors import InvalidDimensionException
  13. except RuntimeError:
  14. from embedchain.utils.misc import use_pysqlite3
  15. use_pysqlite3()
  16. import chromadb
  17. from chromadb.config import Settings
  18. from chromadb.errors import InvalidDimensionException
  19. logger = logging.getLogger(__name__)
  20. @register_deserializable
  21. class ChromaDB(BaseVectorDB):
  22. """Vector database using ChromaDB."""
  23. BATCH_SIZE = 100
  24. def __init__(self, config: Optional[ChromaDbConfig] = None):
  25. """Initialize a new ChromaDB instance
  26. :param config: Configuration options for Chroma, defaults to None
  27. :type config: Optional[ChromaDbConfig], optional
  28. """
  29. if config:
  30. self.config = config
  31. else:
  32. self.config = ChromaDbConfig()
  33. self.settings = Settings(anonymized_telemetry=False)
  34. self.settings.allow_reset = self.config.allow_reset if hasattr(self.config, "allow_reset") else False
  35. if self.config.chroma_settings:
  36. for key, value in self.config.chroma_settings.items():
  37. if hasattr(self.settings, key):
  38. setattr(self.settings, key, value)
  39. if self.config.host and self.config.port:
  40. logger.info(f"Connecting to ChromaDB server: {self.config.host}:{self.config.port}")
  41. self.settings.chroma_server_host = self.config.host
  42. self.settings.chroma_server_http_port = self.config.port
  43. self.settings.chroma_api_impl = "chromadb.api.fastapi.FastAPI"
  44. else:
  45. if self.config.dir is None:
  46. self.config.dir = "db"
  47. self.settings.persist_directory = self.config.dir
  48. self.settings.is_persistent = True
  49. self.client = chromadb.Client(self.settings)
  50. super().__init__(config=self.config)
  51. def _initialize(self):
  52. """
  53. This method is needed because `embedder` attribute needs to be set externally before it can be initialized.
  54. """
  55. if not self.embedder:
  56. raise ValueError(
  57. "Embedder not set. Please set an embedder with `_set_embedder()` function before initialization."
  58. )
  59. self._get_or_create_collection(self.config.collection_name)
  60. def _get_or_create_db(self):
  61. """Called during initialization"""
  62. return self.client
  63. @staticmethod
  64. def _generate_where_clause(where: dict[str, any]) -> dict[str, any]:
  65. # If only one filter is supplied, return it as is
  66. # (no need to wrap in $and based on chroma docs)
  67. if where is None:
  68. return {}
  69. if len(where.keys()) <= 1:
  70. return where
  71. where_filters = []
  72. for k, v in where.items():
  73. if isinstance(v, str):
  74. where_filters.append({k: v})
  75. return {"$and": where_filters}
  76. def _get_or_create_collection(self, name: str) -> Collection:
  77. """
  78. Get or create a named collection.
  79. :param name: Name of the collection
  80. :type name: str
  81. :raises ValueError: No embedder configured.
  82. :return: Created collection
  83. :rtype: Collection
  84. """
  85. if not hasattr(self, "embedder") or not self.embedder:
  86. raise ValueError("Cannot create a Chroma database collection without an embedder.")
  87. self.collection = self.client.get_or_create_collection(
  88. name=name,
  89. embedding_function=self.embedder.embedding_fn,
  90. )
  91. return self.collection
  92. def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
  93. """
  94. Get existing doc ids present in vector database
  95. :param ids: list of doc ids to check for existence
  96. :type ids: list[str]
  97. :param where: Optional. to filter data
  98. :type where: dict[str, Any]
  99. :param limit: Optional. maximum number of documents
  100. :type limit: Optional[int]
  101. :return: Existing documents.
  102. :rtype: list[str]
  103. """
  104. args = {}
  105. if ids:
  106. args["ids"] = ids
  107. if where:
  108. args["where"] = self._generate_where_clause(where)
  109. if limit:
  110. args["limit"] = limit
  111. return self.collection.get(**args)
  112. def add(
  113. self,
  114. documents: list[str],
  115. metadatas: list[object],
  116. ids: list[str],
  117. ) -> Any:
  118. """
  119. Add vectors to chroma database
  120. :param documents: Documents
  121. :type documents: list[str]
  122. :param metadatas: Metadatas
  123. :type metadatas: list[object]
  124. :param ids: ids
  125. :type ids: list[str]
  126. """
  127. size = len(documents)
  128. if len(documents) != size or len(metadatas) != size or len(ids) != size:
  129. raise ValueError(
  130. "Cannot add documents to chromadb with inconsistent sizes. Documents size: {}, Metadata size: {},"
  131. " Ids size: {}".format(len(documents), len(metadatas), len(ids))
  132. )
  133. for i in tqdm(range(0, len(documents), self.BATCH_SIZE), desc="Inserting batches in chromadb"):
  134. self.collection.add(
  135. documents=documents[i : i + self.BATCH_SIZE],
  136. metadatas=metadatas[i : i + self.BATCH_SIZE],
  137. ids=ids[i : i + self.BATCH_SIZE],
  138. )
  139. @staticmethod
  140. def _format_result(results: QueryResult) -> list[tuple[Document, float]]:
  141. """
  142. Format Chroma results
  143. :param results: ChromaDB query results to format.
  144. :type results: QueryResult
  145. :return: Formatted results
  146. :rtype: list[tuple[Document, float]]
  147. """
  148. return [
  149. (Document(page_content=result[0], metadata=result[1] or {}), result[2])
  150. for result in zip(
  151. results["documents"][0],
  152. results["metadatas"][0],
  153. results["distances"][0],
  154. )
  155. ]
  156. def query(
  157. self,
  158. input_query: str,
  159. n_results: int,
  160. where: Optional[dict[str, any]] = None,
  161. raw_filter: Optional[dict[str, any]] = None,
  162. citations: bool = False,
  163. **kwargs: Optional[dict[str, any]],
  164. ) -> Union[list[tuple[str, dict]], list[str]]:
  165. """
  166. Query contents from vector database based on vector similarity
  167. :param input_query: query string
  168. :type input_query: str
  169. :param n_results: no of similar documents to fetch from database
  170. :type n_results: int
  171. :param where: to filter data
  172. :type where: dict[str, Any]
  173. :param raw_filter: Raw filter to apply
  174. :type raw_filter: dict[str, Any]
  175. :param citations: we use citations boolean param to return context along with the answer.
  176. :type citations: bool, default is False.
  177. :raises InvalidDimensionException: Dimensions do not match.
  178. :return: The content of the document that matched your query,
  179. along with url of the source and doc_id (if citations flag is true)
  180. :rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
  181. """
  182. if where and raw_filter:
  183. raise ValueError("Both `where` and `raw_filter` cannot be used together.")
  184. where_clause = {}
  185. if raw_filter:
  186. where_clause = raw_filter
  187. if where:
  188. where_clause = self._generate_where_clause(where)
  189. try:
  190. result = self.collection.query(
  191. query_texts=[
  192. input_query,
  193. ],
  194. n_results=n_results,
  195. where=where_clause,
  196. )
  197. except InvalidDimensionException as e:
  198. raise InvalidDimensionException(
  199. e.message()
  200. + ". This is commonly a side-effect when an embedding function, different from the one used to add the"
  201. " embeddings, is used to retrieve an embedding from the database."
  202. ) from None
  203. results_formatted = self._format_result(result)
  204. contexts = []
  205. for result in results_formatted:
  206. context = result[0].page_content
  207. if citations:
  208. metadata = result[0].metadata
  209. metadata["score"] = result[1]
  210. contexts.append((context, metadata))
  211. else:
  212. contexts.append(context)
  213. return contexts
  214. def set_collection_name(self, name: str):
  215. """
  216. Set the name of the collection. A collection is an isolated space for vectors.
  217. :param name: Name of the collection.
  218. :type name: str
  219. """
  220. if not isinstance(name, str):
  221. raise TypeError("Collection name must be a string")
  222. self.config.collection_name = name
  223. self._get_or_create_collection(self.config.collection_name)
  224. def count(self) -> int:
  225. """
  226. Count number of documents/chunks embedded in the database.
  227. :return: number of documents
  228. :rtype: int
  229. """
  230. return self.collection.count()
  231. def delete(self, where):
  232. return self.collection.delete(where=self._generate_where_clause(where))
  233. def reset(self):
  234. """
  235. Resets the database. Deletes all embeddings irreversibly.
  236. """
  237. # Delete all data from the collection
  238. try:
  239. self.client.delete_collection(self.config.collection_name)
  240. except ValueError:
  241. raise ValueError(
  242. "For safety reasons, resetting is disabled. "
  243. "Please enable it by setting `allow_reset=True` in your ChromaDbConfig"
  244. ) from None
  245. # Recreate
  246. self._get_or_create_collection(self.config.collection_name)
  247. # Todo: Automatically recreating a collection with the same name cannot be the best way to handle a reset.
  248. # A downside of this implementation is, if you have two instances,
  249. # the other instance will not get the updated `self.collection` attribute.
  250. # A better way would be to create the collection if it is called again after being reset.
  251. # That means, checking if collection exists in the db-consuming methods, and creating it if it doesn't.
  252. # That's an extra steps for all uses, just to satisfy a niche use case in a niche method. For now, this will do.