123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254 |
- import copy
- import os
- from typing import Any, Optional, Union
- try:
- from qdrant_client import QdrantClient
- from qdrant_client.http import models
- from qdrant_client.http.models import Batch
- from qdrant_client.models import Distance, VectorParams
- except ImportError:
- raise ImportError("Qdrant requires extra dependencies. Install with `pip install embedchain[qdrant]`") from None
- from tqdm import tqdm
- from embedchain.config.vectordb.qdrant import QdrantDBConfig
- from embedchain.vectordb.base import BaseVectorDB
class QdrantDB(BaseVectorDB):
    """
    Qdrant as vector database
    """

    # Number of points sent per upsert request and fetched per scroll page.
    BATCH_SIZE = 10

    def __init__(self, config: Optional[QdrantDBConfig] = None):
        """
        Qdrant as vector database

        :param config: Qdrant database config to be used for connection.
            A default `QdrantDBConfig` is created when omitted.
        :type config: Optional[QdrantDBConfig]
        :raises TypeError: if `config` is given but is not a `QdrantDBConfig` instance
        """
        if config is None:
            config = QdrantDBConfig()
        elif not isinstance(config, QdrantDBConfig):
            raise TypeError(
                "config is not a `QdrantDBConfig` instance. "
                "Please make sure the type is right and that you are passing an instance."
            )
        self.config = config
        # The endpoint and credentials are read from the environment, not from the config object.
        self.client = QdrantClient(url=os.getenv("QDRANT_URL"), api_key=os.getenv("QDRANT_API_KEY"))
        # Call parent init here because embedder is needed
        super().__init__(config=self.config)

    def _initialize(self):
        """
        This method is needed because `embedder` attribute needs to be set externally before it can be initialized.

        Resolves the collection name and creates the collection when the server does not list it yet.

        :raises ValueError: if no embedder has been set
        """
        if not self.embedder:
            raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")

        self.collection_name = self._get_or_create_collection()
        all_collections = self.client.get_collections()
        collection_names = [collection.name for collection in all_collections.collections]
        if self.collection_name not in collection_names:
            self.client.recreate_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(
                    size=self.embedder.vector_dimension,
                    distance=Distance.COSINE,
                    hnsw_config=self.config.hnsw_config,
                    quantization_config=self.config.quantization_config,
                    on_disk=self.config.on_disk,
                ),
            )

    def _get_or_create_db(self):
        """Return the underlying Qdrant client."""
        return self.client

    def _get_or_create_collection(self):
        """
        Build the collection name from the configured name and the embedder's vector dimension.

        The result is lower-cased and underscores are replaced with hyphens.
        No request is sent to the server here.
        """
        return f"{self.config.collection_name}-{self.embedder.vector_dimension}".lower().replace("_", "-")

    @staticmethod
    def _metadata_conditions(where: dict[str, Any]) -> list:
        """
        Translate a `where` mapping into exact-match conditions on the `metadata.<key>` payload fields.

        :param where: mapping of metadata key to the value it must equal
        :type where: dict[str, Any]
        :return: one `FieldCondition` per entry of `where`
        :rtype: list
        """
        return [
            models.FieldCondition(
                key=f"metadata.{key}",
                match=models.MatchValue(
                    value=value,
                ),
            )
            for key, value in where.items()
        ]

    def get(
        self,
        ids: Optional[list[str]] = None,
        where: Optional[dict[str, Any]] = None,
        limit: Optional[int] = None,
    ) -> dict[str, list]:
        """
        Get existing doc ids (and their metadata) present in vector database

        :param ids: list of doc ids to check for existence
        :type ids: list[str]
        :param where: to filter data
        :type where: dict[str, Any]
        :param limit: The maximum number of entries to be fetched; None fetches every match
        :type limit: Optional int, defaults to None
        :return: the identifiers and metadata of the matching entries,
            as `{"ids": [...], "metadatas": [...]}`
        :rtype: dict[str, list]
        """
        qdrant_must_filters = []
        if ids:
            qdrant_must_filters.append(
                models.FieldCondition(
                    key="identifier",
                    match=models.MatchAny(
                        any=ids,
                    ),
                )
            )
        qdrant_must_filters.extend(self._metadata_conditions(where or {}))
        scroll_filter = models.Filter(must=qdrant_must_filters)

        offset = 0
        existing_ids = []
        metadatas = []
        # `scroll` returns (records, next_offset); next_offset is None once the last page is read.
        while offset is not None:
            page_size = self.BATCH_SIZE
            if limit is not None:
                # Never request more than what is still missing to reach `limit`.
                page_size = min(page_size, limit - len(existing_ids))
                if page_size <= 0:
                    break
            records, offset = self.client.scroll(
                collection_name=self.collection_name,
                scroll_filter=scroll_filter,
                offset=offset,
                limit=page_size,
            )
            for doc in records:
                existing_ids.append(doc.payload["identifier"])
                metadatas.append(doc.payload["metadata"])
        return {"ids": existing_ids, "metadatas": metadatas}

    def add(
        self,
        documents: list[str],
        metadatas: list[object],
        ids: list[str],
        **kwargs: Optional[dict[str, Any]],
    ):
        """add data in vector database

        Documents are embedded with the configured embedder and upserted in batches of `BATCH_SIZE`.

        :param documents: list of texts to add
        :type documents: list[str]
        :param metadatas: list of metadata associated with docs
        :type metadatas: list[object]
        :param ids: ids of docs
        :type ids: list[str]
        :param kwargs: extra keyword arguments forwarded to the client's `upsert`
        """
        embeddings = self.embedder.embedding_fn(documents)

        payloads = []
        qdrant_ids = []
        for doc_id, document, metadata in zip(ids, documents, metadatas):
            # Copy first, then add the text: the stored payload is unchanged,
            # but the caller's metadata objects are no longer modified.
            stored_metadata = copy.deepcopy(metadata)
            stored_metadata["text"] = document
            qdrant_ids.append(doc_id)
            payloads.append({"identifier": doc_id, "text": document, "metadata": stored_metadata})

        for i in tqdm(range(0, len(qdrant_ids), self.BATCH_SIZE), desc="Adding data in batches"):
            self.client.upsert(
                collection_name=self.collection_name,
                points=Batch(
                    ids=qdrant_ids[i : i + self.BATCH_SIZE],
                    payloads=payloads[i : i + self.BATCH_SIZE],
                    vectors=embeddings[i : i + self.BATCH_SIZE],
                ),
                **kwargs,
            )

    def query(
        self,
        input_query: str,
        n_results: int,
        where: dict[str, Any],
        citations: bool = False,
        **kwargs: Optional[dict[str, Any]],
    ) -> Union[list[tuple[str, dict]], list[str]]:
        """
        query contents from vector database based on vector similarity

        :param input_query: query string
        :type input_query: str
        :param n_results: no of similar documents to fetch from database
        :type n_results: int
        :param where: Optional. to filter data; None or an empty dict applies no metadata filter
        :type where: dict[str, Any]
        :param citations: we use citations boolean param to return context along with the answer.
        :type citations: bool, default is False.
        :param kwargs: extra keyword arguments forwarded to the client's `search`
        :return: The content of the documents that matched your query; when citations is true,
            each entry is paired with its metadata, to which the similarity `score` is added
        :rtype: list[str], if citations=False, otherwise list[tuple[str, dict]]
        """
        query_vector = self.embedder.embedding_fn([input_query])[0]

        results = self.client.search(
            collection_name=self.collection_name,
            query_filter=models.Filter(must=self._metadata_conditions(where or {})),
            query_vector=query_vector,
            limit=n_results,
            **kwargs,
        )

        contexts = []
        for result in results:
            context = result.payload["text"]
            if citations:
                metadata = result.payload["metadata"]
                metadata["score"] = result.score
                contexts.append((context, metadata))
            else:
                contexts.append(context)
        return contexts

    def count(self) -> int:
        """Return the `points_count` the server reports for the current collection."""
        response = self.client.get_collection(collection_name=self.collection_name)
        return response.points_count

    def reset(self):
        """Delete the current collection, then run `_initialize` again to set it up afresh."""
        self.client.delete_collection(collection_name=self.collection_name)
        self._initialize()

    def set_collection_name(self, name: str):
        """
        Set the name of the collection. A collection is an isolated space for vectors.

        :param name: Name of the collection.
        :type name: str
        :raises TypeError: if `name` is not a string
        """
        if not isinstance(name, str):
            raise TypeError("Collection name must be a string")
        self.config.collection_name = name
        # The effective name also encodes the vector dimension, so it is recomputed here.
        self.collection_name = self._get_or_create_collection()

    @staticmethod
    def _generate_query(where: dict):
        """
        Build a Qdrant filter requiring every `where` entry to match the corresponding metadata field.

        :param where: mapping of metadata key to the value it must equal
        :type where: dict
        :return: a filter whose `must` list holds one condition per entry
        """
        return models.Filter(must=QdrantDB._metadata_conditions(where))

    def delete(self, where: dict):
        """
        Delete the points of the current collection selected by `where`.

        :param where: mapping of metadata key to the value it must equal
        :type where: dict
        """
        # NOTE(review): an empty `where` produces a filter with no conditions;
        # confirm how Qdrant treats that before calling this with `{}`.
        db_filter = self._generate_query(where)
        self.client.delete(collection_name=self.collection_name, points_selector=db_filter)
|