qdrant.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. import copy
  2. import os
  3. from typing import Any, Optional, Union
  4. try:
  5. from qdrant_client import QdrantClient
  6. from qdrant_client.http import models
  7. from qdrant_client.http.models import Batch
  8. from qdrant_client.models import Distance, VectorParams
  9. except ImportError:
  10. raise ImportError("Qdrant requires extra dependencies. Install with `pip install embedchain[qdrant]`") from None
  11. from tqdm import tqdm
  12. from embedchain.config.vectordb.qdrant import QdrantDBConfig
  13. from embedchain.vectordb.base import BaseVectorDB
  14. class QdrantDB(BaseVectorDB):
  15. """
  16. Qdrant as vector database
  17. """
  18. def __init__(self, config: QdrantDBConfig = None):
  19. """
  20. Qdrant as vector database
  21. :param config. Qdrant database config to be used for connection
  22. """
  23. if config is None:
  24. config = QdrantDBConfig()
  25. else:
  26. if not isinstance(config, QdrantDBConfig):
  27. raise TypeError(
  28. "config is not a `QdrantDBConfig` instance. "
  29. "Please make sure the type is right and that you are passing an instance."
  30. )
  31. self.config = config
  32. self.client = QdrantClient(url=os.getenv("QDRANT_URL"), api_key=os.getenv("QDRANT_API_KEY"))
  33. # Call parent init here because embedder is needed
  34. super().__init__(config=self.config)
  35. def _initialize(self):
  36. """
  37. This method is needed because `embedder` attribute needs to be set externally before it can be initialized.
  38. """
  39. if not self.embedder:
  40. raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
  41. self.collection_name = self._get_or_create_collection()
  42. all_collections = self.client.get_collections()
  43. collection_names = [collection.name for collection in all_collections.collections]
  44. if self.collection_name not in collection_names:
  45. self.client.recreate_collection(
  46. collection_name=self.collection_name,
  47. vectors_config=VectorParams(
  48. size=self.embedder.vector_dimension,
  49. distance=Distance.COSINE,
  50. hnsw_config=self.config.hnsw_config,
  51. quantization_config=self.config.quantization_config,
  52. on_disk=self.config.on_disk,
  53. ),
  54. )
  55. def _get_or_create_db(self):
  56. return self.client
  57. def _get_or_create_collection(self):
  58. return f"{self.config.collection_name}-{self.embedder.vector_dimension}".lower().replace("_", "-")
  59. def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
  60. """
  61. Get existing doc ids present in vector database
  62. :param ids: _list of doc ids to check for existence
  63. :type ids: list[str]
  64. :param where: to filter data
  65. :type where: dict[str, any]
  66. :param limit: The number of entries to be fetched
  67. :type limit: Optional int, defaults to None
  68. :return: All the existing IDs
  69. :rtype: Set[str]
  70. """
  71. keys = set(where.keys() if where is not None else set())
  72. qdrant_must_filters = []
  73. if ids:
  74. qdrant_must_filters.append(
  75. models.FieldCondition(
  76. key="identifier",
  77. match=models.MatchAny(
  78. any=ids,
  79. ),
  80. )
  81. )
  82. if len(keys) > 0:
  83. for key in keys:
  84. qdrant_must_filters.append(
  85. models.FieldCondition(
  86. key="metadata.{}".format(key),
  87. match=models.MatchValue(
  88. value=where.get(key),
  89. ),
  90. )
  91. )
  92. offset = 0
  93. existing_ids = []
  94. metadatas = []
  95. while offset is not None:
  96. response = self.client.scroll(
  97. collection_name=self.collection_name,
  98. scroll_filter=models.Filter(must=qdrant_must_filters),
  99. offset=offset,
  100. limit=self.config.batch_size,
  101. )
  102. offset = response[1]
  103. for doc in response[0]:
  104. existing_ids.append(doc.payload["identifier"])
  105. metadatas.append(doc.payload["metadata"])
  106. return {"ids": existing_ids, "metadatas": metadatas}
  107. def add(
  108. self,
  109. documents: list[str],
  110. metadatas: list[object],
  111. ids: list[str],
  112. **kwargs: Optional[dict[str, any]],
  113. ):
  114. """add data in vector database
  115. :param documents: list of texts to add
  116. :type documents: list[str]
  117. :param metadatas: list of metadata associated with docs
  118. :type metadatas: list[object]
  119. :param ids: ids of docs
  120. :type ids: list[str]
  121. """
  122. embeddings = self.embedder.embedding_fn(documents)
  123. payloads = []
  124. qdrant_ids = []
  125. for id, document, metadata in zip(ids, documents, metadatas):
  126. metadata["text"] = document
  127. qdrant_ids.append(id)
  128. payloads.append({"identifier": id, "text": document, "metadata": copy.deepcopy(metadata)})
  129. for i in tqdm(range(0, len(qdrant_ids), self.config.batch_size), desc="Adding data in batches"):
  130. self.client.upsert(
  131. collection_name=self.collection_name,
  132. points=Batch(
  133. ids=qdrant_ids[i : i + self.config.batch_size],
  134. payloads=payloads[i : i + self.config.batch_size],
  135. vectors=embeddings[i : i + self.config.batch_size],
  136. ),
  137. **kwargs,
  138. )
  139. def query(
  140. self,
  141. input_query: str,
  142. n_results: int,
  143. where: dict[str, any],
  144. citations: bool = False,
  145. **kwargs: Optional[dict[str, Any]],
  146. ) -> Union[list[tuple[str, dict]], list[str]]:
  147. """
  148. query contents from vector database based on vector similarity
  149. :param input_query: query string
  150. :type input_query: str
  151. :param n_results: no of similar documents to fetch from database
  152. :type n_results: int
  153. :param where: Optional. to filter data
  154. :type where: dict[str, any]
  155. :param citations: we use citations boolean param to return context along with the answer.
  156. :type citations: bool, default is False.
  157. :return: The content of the document that matched your query,
  158. along with url of the source and doc_id (if citations flag is true)
  159. :rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
  160. """
  161. query_vector = self.embedder.embedding_fn([input_query])[0]
  162. keys = set(where.keys() if where is not None else set())
  163. qdrant_must_filters = []
  164. if len(keys) > 0:
  165. for key in keys:
  166. qdrant_must_filters.append(
  167. models.FieldCondition(
  168. key="metadata.{}".format(key),
  169. match=models.MatchValue(
  170. value=where.get(key),
  171. ),
  172. )
  173. )
  174. results = self.client.search(
  175. collection_name=self.collection_name,
  176. query_filter=models.Filter(must=qdrant_must_filters),
  177. query_vector=query_vector,
  178. limit=n_results,
  179. **kwargs,
  180. )
  181. contexts = []
  182. for result in results:
  183. context = result.payload["text"]
  184. if citations:
  185. metadata = result.payload["metadata"]
  186. metadata["score"] = result.score
  187. contexts.append(tuple((context, metadata)))
  188. else:
  189. contexts.append(context)
  190. return contexts
  191. def count(self) -> int:
  192. response = self.client.get_collection(collection_name=self.collection_name)
  193. return response.points_count
  194. def reset(self):
  195. self.client.delete_collection(collection_name=self.collection_name)
  196. self._initialize()
  197. def set_collection_name(self, name: str):
  198. """
  199. Set the name of the collection. A collection is an isolated space for vectors.
  200. :param name: Name of the collection.
  201. :type name: str
  202. """
  203. if not isinstance(name, str):
  204. raise TypeError("Collection name must be a string")
  205. self.config.collection_name = name
  206. self.collection_name = self._get_or_create_collection()
  207. @staticmethod
  208. def _generate_query(where: dict):
  209. must_fields = []
  210. for key, value in where.items():
  211. must_fields.append(
  212. models.FieldCondition(
  213. key=f"metadata.{key}",
  214. match=models.MatchValue(
  215. value=value,
  216. ),
  217. )
  218. )
  219. return models.Filter(must=must_fields)
  220. def delete(self, where: dict):
  221. db_filter = self._generate_query(where)
  222. self.client.delete(collection_name=self.collection_name, points_selector=db_filter)