123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305 |
- from typing import Any, Dict, List, Optional, Union
- import pyarrow as pa
- try:
- import lancedb
- except ImportError:
- raise ImportError('LanceDB is required. Install with pip install "embedchain[lancedb]"') from None
- from embedchain.config.vectordb.lancedb import LanceDBConfig
- from embedchain.helpers.json_serializable import register_deserializable
- from embedchain.vectordb.base import BaseVectorDB
@register_deserializable
class LanceDB(BaseVectorDB):
    """
    LanceDB as vector database
    """

    def __init__(
        self,
        config: Optional[LanceDBConfig] = None,
    ):
        """LanceDB as vector database.

        :param config: LanceDB database config, defaults to None
        :type config: LanceDBConfig, optional
        """
        self.config = config if config else LanceDBConfig()
        self.client = lancedb.connect(self.config.dir or "~/.lancedb")
        # Flipped to False by `_initialize` when a probe call to the embedder
        # fails; tables are then created without a "vector" column.
        self.embedder_check = True
        super().__init__(config=self.config)

    def _initialize(self):
        """
        Probe the embedder and open (or create) the configured table.

        This method is needed because `embedder` attribute needs to be set externally before it can be initialized.

        :raises ValueError: if no embedder has been set.
        """
        # getattr: `embedder` only exists once `_set_embedder()` has run.
        if not getattr(self, "embedder", None):
            raise ValueError(
                "Embedder not set. Please set an embedder with `_set_embedder()` function before initialization."
            )
        # Check whether the embedding function works, calling it with a list of
        # texts exactly as `add` does. A failure is tolerated on purpose: the
        # table is then created without a vector column.
        try:
            self.embedder.embedding_fn(["Hello LanceDB"])
        except Exception:
            self.embedder_check = False
        self._get_or_create_collection(self.config.collection_name)

    def _get_or_create_db(self):
        """
        Called during initialization. Returns the LanceDB connection.
        """
        return self.client

    @staticmethod
    def _sql_literal(value: Any) -> str:
        """
        Render a Python value as a literal for a LanceDB SQL filter.

        Strings are single-quoted with embedded quotes doubled; any other value
        is rendered with ``str()``.
        """
        if isinstance(value, str):
            escaped = value.replace("'", "''")
            return f"'{escaped}'"
        return str(value)

    def _generate_where_clause(self, where: Dict[str, Any]) -> str:
        """
        Build a SQL filter ``key1 = v1 AND key2 = v2 ...`` from a dictionary.

        :param where: mapping of column name to required value
        :type where: Dict[str, Any]
        :return: the filter expression, or "" for an empty mapping
        :rtype: str
        """
        return " AND ".join(f"{key} = {self._sql_literal(value)}" for key, value in where.items())

    def _get_or_create_collection(self, table_name: str, reset=False):
        """
        Get or create a named collection (a LanceDB table) and store it on ``self.collection``.

        :param table_name: Name of the table
        :type table_name: str
        :param reset: if True, drop the table and recreate it empty, defaults to False
        :type reset: bool
        :return: The opened table
        :rtype: lancedb table
        """
        fields = [
            pa.field("doc", pa.string()),
            pa.field("metadata", pa.string()),
            pa.field("id", pa.string()),
        ]
        if self.embedder_check:
            # Fixed-size vector column sized by the embedder's dimension.
            fields.insert(0, pa.field("vector", pa.list_(pa.float32(), list_size=self.embedder.vector_dimension)))
        schema = pa.schema(fields)

        if reset:
            self.client.drop_table(table_name)
            self.client.create_table(table_name, schema=schema)
        elif table_name not in self.client.table_names():
            self.client.create_table(table_name, schema=schema)
        self.collection = self.client[table_name]
        return self.collection

    def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, Any]] = None, limit: Optional[int] = None):
        """
        Get existing doc ids present in vector database

        :param ids: list of doc ids to check for existence
        :type ids: List[str]
        :param where: Optional. Column/value pairs that matching rows must also satisfy
        :type where: Dict[str, Any]
        :param limit: Optional. Maximum number of rows to return in total (None means no cap)
        :type limit: Optional[int]
        :return: Existing documents as ``{"ids": [...], "metadatas": [...]}``; metadatas
            are the stored metadata strings.
        :rtype: Dict[str, List[str]]
        """
        results = {"ids": [], "metadatas": []}
        # NOTE(review): as before, nothing is looked up when no ids are given,
        # even if `where` is set. `where` keys must be real table columns
        # (metadata is stored as one string column).
        if not ids:
            return results

        id_list = ", ".join(self._sql_literal(doc_id) for doc_id in ids)
        filters = [f"id IN ({id_list})"]
        if where:
            filters.append(f"({self._generate_where_clause(where)})")

        records = (
            self.collection.to_lance()
            .scanner(filter=" AND ".join(filters), columns=["id", "metadata"], limit=limit)
            .to_table()
            .to_pydict()
        )
        results["ids"] = records["id"]
        results["metadatas"] = records["metadata"]
        return results

    def add(
        self,
        documents: List[str],
        metadatas: List[object],
        ids: List[str],
    ) -> Any:
        """
        Add vectors to lancedb database

        :param documents: Documents
        :type documents: List[str]
        :param metadatas: Metadatas (stored via ``str()``)
        :type metadatas: List[object]
        :param ids: ids
        :type ids: List[str]
        """
        rows = []
        for doc, meta, doc_id in zip(documents, metadatas, ids):
            row = {"doc": doc}
            if self.embedder_check:
                row["vector"] = self.embedder.embedding_fn([doc])[0]
            row["metadata"] = str(meta)
            row["id"] = doc_id
            rows.append(row)
        if not rows:
            # Nothing to insert; skip the call rather than adding an empty batch.
            return
        self.collection.add(data=rows)

    def _format_result(self, results) -> list:
        """
        Format LanceDB results

        :param results: LanceDB query results to format.
        :type results: QueryResult
        :return: Formatted results
        :rtype: list[tuple[Document, float]]
        """
        return results.tolist()

    def query(
        self,
        input_query: str,
        n_results: int = 3,
        where: Optional[dict[str, Any]] = None,
        raw_filter: Optional[dict[str, Any]] = None,
        citations: bool = False,
        **kwargs: Optional[dict[str, Any]],
    ) -> Union[list[tuple[str, dict]], list[str]]:
        """
        Query contents from vector database based on vector similarity

        :param input_query: query string
        :type input_query: str
        :param n_results: no of similar documents to fetch from database
        :type n_results: int
        :param where: to filter data (validated but currently not applied)
        :type where: dict[str, Any]
        :param raw_filter: Raw filter to apply (validated but currently not applied)
        :type raw_filter: dict[str, Any]
        :param citations: we use citations boolean param to return context along with the answer.
        :type citations: bool, default is False.
        :param kwargs: accepted for interface compatibility; ignored
        :raises ValueError: if both `where` and `raw_filter` are given.
        :return: The matched documents' text, or ``(doc, metadata)`` tuples when
            citations is true (metadata is the stored metadata string).
        :rtype: list[str], if citations=False, otherwise list[tuple[str, str]]
        """
        if where and raw_filter:
            raise ValueError("Both `where` and `raw_filter` cannot be used together.")
        # NOTE(review): filters are not pushed down because metadata keys are
        # not table columns here. Errors from the embedder or the search now
        # propagate unchanged (the old handler crashed on `e.message()`).
        query_embedding = self.embedder.embedding_fn([input_query])[0]
        matches = self.collection.search(query_embedding).limit(n_results).to_list()
        if citations:
            return [(match["doc"], match["metadata"]) for match in matches]
        return [match["doc"] for match in matches]

    def set_collection_name(self, name: str):
        """
        Set the name of the collection. A collection is an isolated space for vectors.

        :param name: Name of the collection.
        :type name: str
        :raises TypeError: if name is not a string.
        """
        if not isinstance(name, str):
            raise TypeError("Collection name must be a string")
        self.config.collection_name = name
        self._get_or_create_collection(self.config.collection_name)

    def count(self) -> int:
        """
        Count number of documents/chunks embedded in the database.

        :return: number of documents
        :rtype: int
        """
        return self.collection.count_rows()

    def delete(self, where):
        """
        Delete rows matching a filter.

        :param where: a SQL filter string, or a dict of column/value pairs that is
            converted with ``_generate_where_clause``
        :return: whatever the table's ``delete`` returns
        """
        if isinstance(where, dict):
            where = self._generate_where_clause(where)
        return self.collection.delete(where=where)

    def reset(self):
        """
        Resets the database. Deletes all embeddings irreversibly.

        Only acts when ``allow_reset`` is enabled in the config; otherwise a
        message is printed and nothing is deleted.
        """
        if not self.config.allow_reset:
            print(
                "For safety reasons, resetting is disabled. "
                "Please enable it by setting `allow_reset=True` in your LanceDbConfig"
            )
            return
        # Drop the table and recreate it empty. Errors propagate with their real
        # cause instead of being relabelled as "resetting is disabled".
        self._get_or_create_collection(self.config.collection_name, reset=True)
|