qdrant.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. import copy
  2. import os
  3. import uuid
  4. from typing import Any, Dict, List, Optional, Tuple, Union
  5. try:
  6. from qdrant_client import QdrantClient
  7. from qdrant_client.http import models
  8. from qdrant_client.http.models import Batch
  9. from qdrant_client.models import Distance, VectorParams
  10. except ImportError:
  11. raise ImportError("Qdrant requires extra dependencies. Install with `pip install embedchain[qdrant]`") from None
  12. from embedchain.config.vectordb.qdrant import QdrantDBConfig
  13. from embedchain.vectordb.base import BaseVectorDB
  14. class QdrantDB(BaseVectorDB):
  15. """
  16. Qdrant as vector database
  17. """
  18. BATCH_SIZE = 10
  19. def __init__(self, config: QdrantDBConfig = None):
  20. """
  21. Qdrant as vector database
  22. :param config. Qdrant database config to be used for connection
  23. """
  24. if config is None:
  25. config = QdrantDBConfig()
  26. else:
  27. if not isinstance(config, QdrantDBConfig):
  28. raise TypeError(
  29. "config is not a `QdrantDBConfig` instance. "
  30. "Please make sure the type is right and that you are passing an instance."
  31. )
  32. self.config = config
  33. self.client = QdrantClient(url=os.getenv("QDRANT_URL"), api_key=os.getenv("QDRANT_API_KEY"))
  34. # Call parent init here because embedder is needed
  35. super().__init__(config=self.config)
  36. def _initialize(self):
  37. """
  38. This method is needed because `embedder` attribute needs to be set externally before it can be initialized.
  39. """
  40. if not self.embedder:
  41. raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
  42. self.collection_name = self._get_or_create_collection()
  43. self.metadata_keys = {"data_type", "doc_id", "url", "hash", "app_id", "text"}
  44. all_collections = self.client.get_collections()
  45. collection_names = [collection.name for collection in all_collections.collections]
  46. if self.collection_name not in collection_names:
  47. self.client.recreate_collection(
  48. collection_name=self.collection_name,
  49. vectors_config=VectorParams(
  50. size=self.embedder.vector_dimension,
  51. distance=Distance.COSINE,
  52. hnsw_config=self.config.hnsw_config,
  53. quantization_config=self.config.quantization_config,
  54. on_disk=self.config.on_disk,
  55. ),
  56. )
  57. def _get_or_create_db(self):
  58. return self.client
  59. def _get_or_create_collection(self):
  60. return f"{self.config.collection_name}-{self.embedder.vector_dimension}".lower().replace("_", "-")
  61. def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
  62. """
  63. Get existing doc ids present in vector database
  64. :param ids: _list of doc ids to check for existence
  65. :type ids: List[str]
  66. :param where: to filter data
  67. :type where: Dict[str, any]
  68. :param limit: The number of entries to be fetched
  69. :type limit: Optional int, defaults to None
  70. :return: All the existing IDs
  71. :rtype: Set[str]
  72. """
  73. if ids is None or len(ids) == 0:
  74. return {"ids": []}
  75. keys = set(where.keys() if where is not None else set())
  76. qdrant_must_filters = [
  77. models.FieldCondition(
  78. key="identifier",
  79. match=models.MatchAny(
  80. any=ids,
  81. ),
  82. )
  83. ]
  84. if len(keys.intersection(self.metadata_keys)) != 0:
  85. for key in keys.intersection(self.metadata_keys):
  86. qdrant_must_filters.append(
  87. models.FieldCondition(
  88. key="metadata.{}".format(key),
  89. match=models.MatchValue(
  90. value=where.get(key),
  91. ),
  92. )
  93. )
  94. offset = 0
  95. existing_ids = []
  96. while offset is not None:
  97. response = self.client.scroll(
  98. collection_name=self.collection_name,
  99. scroll_filter=models.Filter(must=qdrant_must_filters),
  100. offset=offset,
  101. limit=self.BATCH_SIZE,
  102. )
  103. offset = response[1]
  104. for doc in response[0]:
  105. existing_ids.append(doc.payload["identifier"])
  106. return {"ids": existing_ids}
  107. def add(
  108. self,
  109. embeddings: List[List[float]],
  110. documents: List[str],
  111. metadatas: List[object],
  112. ids: List[str],
  113. skip_embedding: bool,
  114. **kwargs: Optional[Dict[str, any]],
  115. ):
  116. """add data in vector database
  117. :param embeddings: list of embeddings for the corresponding documents to be added
  118. :type documents: List[List[float]]
  119. :param documents: list of texts to add
  120. :type documents: List[str]
  121. :param metadatas: list of metadata associated with docs
  122. :type metadatas: List[object]
  123. :param ids: ids of docs
  124. :type ids: List[str]
  125. :param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
  126. generated or not
  127. :type skip_embedding: bool
  128. """
  129. if not skip_embedding:
  130. embeddings = self.embedder.embedding_fn(documents)
  131. payloads = []
  132. qdrant_ids = []
  133. for id, document, metadata in zip(ids, documents, metadatas):
  134. metadata["text"] = document
  135. qdrant_ids.append(str(uuid.uuid4()))
  136. payloads.append({"identifier": id, "text": document, "metadata": copy.deepcopy(metadata)})
  137. for i in range(0, len(qdrant_ids), self.BATCH_SIZE):
  138. self.client.upsert(
  139. collection_name=self.collection_name,
  140. points=Batch(
  141. ids=qdrant_ids[i : i + self.BATCH_SIZE],
  142. payloads=payloads[i : i + self.BATCH_SIZE],
  143. vectors=embeddings[i : i + self.BATCH_SIZE],
  144. ),
  145. **kwargs,
  146. )
  147. def query(
  148. self,
  149. input_query: List[str],
  150. n_results: int,
  151. where: Dict[str, any],
  152. skip_embedding: bool,
  153. citations: bool = False,
  154. **kwargs: Optional[Dict[str, Any]],
  155. ) -> Union[List[Tuple[str, Dict]], List[str]]:
  156. """
  157. query contents from vector database based on vector similarity
  158. :param input_query: list of query string
  159. :type input_query: List[str]
  160. :param n_results: no of similar documents to fetch from database
  161. :type n_results: int
  162. :param where: Optional. to filter data
  163. :type where: Dict[str, any]
  164. :param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
  165. generated or not
  166. :type skip_embedding: bool
  167. :param citations: we use citations boolean param to return context along with the answer.
  168. :type citations: bool, default is False.
  169. :return: The content of the document that matched your query,
  170. along with url of the source and doc_id (if citations flag is true)
  171. :rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
  172. """
  173. if not skip_embedding:
  174. query_vector = self.embedder.embedding_fn([input_query])[0]
  175. else:
  176. query_vector = input_query
  177. keys = set(where.keys() if where is not None else set())
  178. qdrant_must_filters = []
  179. if len(keys.intersection(self.metadata_keys)) != 0:
  180. for key in keys.intersection(self.metadata_keys):
  181. qdrant_must_filters.append(
  182. models.FieldCondition(
  183. key="payload.metadata.{}".format(key),
  184. match=models.MatchValue(
  185. value=where.get(key),
  186. ),
  187. )
  188. )
  189. results = self.client.search(
  190. collection_name=self.collection_name,
  191. query_filter=models.Filter(must=qdrant_must_filters),
  192. query_vector=query_vector,
  193. limit=n_results,
  194. **kwargs,
  195. )
  196. contexts = []
  197. for result in results:
  198. context = result.payload["text"]
  199. if citations:
  200. metadata = result.payload["metadata"]
  201. metadata["score"] = result.score
  202. contexts.append(tuple((context, metadata)))
  203. else:
  204. contexts.append(context)
  205. return contexts
  206. def count(self) -> int:
  207. response = self.client.get_collection(collection_name=self.collection_name)
  208. return response.points_count
  209. def reset(self):
  210. self.client.delete_collection(collection_name=self.collection_name)
  211. self._initialize()
  212. def set_collection_name(self, name: str):
  213. """
  214. Set the name of the collection. A collection is an isolated space for vectors.
  215. :param name: Name of the collection.
  216. :type name: str
  217. """
  218. if not isinstance(name, str):
  219. raise TypeError("Collection name must be a string")
  220. self.config.collection_name = name
  221. self.collection_name = self._get_or_create_collection()