# lancedb.py - LanceDB vector database backend for embedchain.

import logging
from typing import Any, Dict, List, Optional, Union

import pyarrow as pa

try:
    import lancedb
except ImportError:
    raise ImportError('LanceDB is required. Install with pip install "embedchain[lancedb]"') from None

from embedchain.config.vectordb.lancedb import LanceDBConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.vectordb.base import BaseVectorDB


  10. @register_deserializable
  11. class LanceDB(BaseVectorDB):
  12. """
  13. LanceDB as vector database
  14. """
  15. def __init__(
  16. self,
  17. config: Optional[LanceDBConfig] = None,
  18. ):
  19. """LanceDB as vector database.
  20. :param config: LanceDB database config, defaults to None
  21. :type config: LanceDBConfig, optional
  22. """
  23. if config:
  24. self.config = config
  25. else:
  26. self.config = LanceDBConfig()
  27. self.client = lancedb.connect(self.config.dir or "~/.lancedb")
  28. self.embedder_check = True
  29. super().__init__(config=self.config)
  30. def _initialize(self):
  31. """
  32. This method is needed because `embedder` attribute needs to be set externally before it can be initialized.
  33. """
  34. if not self.embedder:
  35. raise ValueError(
  36. "Embedder not set. Please set an embedder with `_set_embedder()` function before initialization."
  37. )
  38. else:
  39. # check embedder function is working or not
  40. try:
  41. self.embedder.embedding_fn("Hello LanceDB")
  42. except Exception:
  43. self.embedder_check = False
  44. self._get_or_create_collection(self.config.collection_name)
  45. def _get_or_create_db(self):
  46. """
  47. Called during initialization
  48. """
  49. return self.client
  50. def _generate_where_clause(self, where: Dict[str, any]) -> str:
  51. """
  52. This method generate where clause using dictionary containing attributes and their values
  53. """
  54. where_filters = ""
  55. if len(list(where.keys())) == 1:
  56. where_filters = f"{list(where.keys())[0]} = {list(where.values())[0]}"
  57. return where_filters
  58. where_items = list(where.items())
  59. where_count = len(where_items)
  60. for i, (key, value) in enumerate(where_items, start=1):
  61. condition = f"{key} = {value} AND "
  62. where_filters += condition
  63. if i == where_count:
  64. condition = f"{key} = {value}"
  65. where_filters += condition
  66. return where_filters
  67. def _get_or_create_collection(self, table_name: str, reset=False):
  68. """
  69. Get or create a named collection.
  70. :param name: Name of the collection
  71. :type name: str
  72. :return: Created collection
  73. :rtype: Collection
  74. """
  75. if not self.embedder_check:
  76. schema = pa.schema(
  77. [
  78. pa.field("doc", pa.string()),
  79. pa.field("metadata", pa.string()),
  80. pa.field("id", pa.string()),
  81. ]
  82. )
  83. else:
  84. schema = pa.schema(
  85. [
  86. pa.field("vector", pa.list_(pa.float32(), list_size=self.embedder.vector_dimension)),
  87. pa.field("doc", pa.string()),
  88. pa.field("metadata", pa.string()),
  89. pa.field("id", pa.string()),
  90. ]
  91. )
  92. if not reset:
  93. if table_name not in self.client.table_names():
  94. self.collection = self.client.create_table(table_name, schema=schema)
  95. else:
  96. self.client.drop_table(table_name)
  97. self.collection = self.client.create_table(table_name, schema=schema)
  98. self.collection = self.client[table_name]
  99. return self.collection
  100. def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
  101. """
  102. Get existing doc ids present in vector database
  103. :param ids: list of doc ids to check for existence
  104. :type ids: List[str]
  105. :param where: Optional. to filter data
  106. :type where: Dict[str, Any]
  107. :param limit: Optional. maximum number of documents
  108. :type limit: Optional[int]
  109. :return: Existing documents.
  110. :rtype: List[str]
  111. """
  112. if limit is not None:
  113. max_limit = limit
  114. else:
  115. max_limit = 3
  116. results = {"ids": [], "metadatas": []}
  117. where_clause = {}
  118. if where:
  119. where_clause = self._generate_where_clause(where)
  120. if ids is not None:
  121. records = (
  122. self.collection.to_lance().scanner(filter=f"id IN {tuple(ids)}", columns=["id"]).to_table().to_pydict()
  123. )
  124. for id in records["id"]:
  125. if where is not None:
  126. result = (
  127. self.collection.search(query=id, vector_column_name="id")
  128. .where(where_clause)
  129. .limit(max_limit)
  130. .to_list()
  131. )
  132. else:
  133. result = self.collection.search(query=id, vector_column_name="id").limit(max_limit).to_list()
  134. results["ids"] = [r["id"] for r in result]
  135. results["metadatas"] = [r["metadata"] for r in result]
  136. return results
  137. def add(
  138. self,
  139. documents: List[str],
  140. metadatas: List[object],
  141. ids: List[str],
  142. ) -> Any:
  143. """
  144. Add vectors to lancedb database
  145. :param documents: Documents
  146. :type documents: List[str]
  147. :param metadatas: Metadatas
  148. :type metadatas: List[object]
  149. :param ids: ids
  150. :type ids: List[str]
  151. """
  152. data = []
  153. to_ingest = list(zip(documents, metadatas, ids))
  154. if not self.embedder_check:
  155. for doc, meta, id in to_ingest:
  156. temp = {}
  157. temp["doc"] = doc
  158. temp["metadata"] = str(meta)
  159. temp["id"] = id
  160. data.append(temp)
  161. else:
  162. for doc, meta, id in to_ingest:
  163. temp = {}
  164. temp["doc"] = doc
  165. temp["vector"] = self.embedder.embedding_fn([doc])[0]
  166. temp["metadata"] = str(meta)
  167. temp["id"] = id
  168. data.append(temp)
  169. self.collection.add(data=data)
  170. def _format_result(self, results) -> list:
  171. """
  172. Format LanceDB results
  173. :param results: LanceDB query results to format.
  174. :type results: QueryResult
  175. :return: Formatted results
  176. :rtype: list[tuple[Document, float]]
  177. """
  178. return results.tolist()
  179. def query(
  180. self,
  181. input_query: str,
  182. n_results: int = 3,
  183. where: Optional[dict[str, any]] = None,
  184. raw_filter: Optional[dict[str, any]] = None,
  185. citations: bool = False,
  186. **kwargs: Optional[dict[str, any]],
  187. ) -> Union[list[tuple[str, dict]], list[str]]:
  188. """
  189. Query contents from vector database based on vector similarity
  190. :param input_query: query string
  191. :type input_query: str
  192. :param n_results: no of similar documents to fetch from database
  193. :type n_results: int
  194. :param where: to filter data
  195. :type where: dict[str, Any]
  196. :param raw_filter: Raw filter to apply
  197. :type raw_filter: dict[str, Any]
  198. :param citations: we use citations boolean param to return context along with the answer.
  199. :type citations: bool, default is False.
  200. :raises InvalidDimensionException: Dimensions do not match.
  201. :return: The content of the document that matched your query,
  202. along with url of the source and doc_id (if citations flag is true)
  203. :rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
  204. """
  205. if where and raw_filter:
  206. raise ValueError("Both `where` and `raw_filter` cannot be used together.")
  207. try:
  208. query_embedding = self.embedder.embedding_fn(input_query)[0]
  209. result = self.collection.search(query_embedding).limit(n_results).to_list()
  210. except Exception as e:
  211. e.message()
  212. results_formatted = result
  213. contexts = []
  214. for result in results_formatted:
  215. if citations:
  216. metadata = result["metadata"]
  217. contexts.append((result["doc"], metadata))
  218. else:
  219. contexts.append(result["doc"])
  220. return contexts
  221. def set_collection_name(self, name: str):
  222. """
  223. Set the name of the collection. A collection is an isolated space for vectors.
  224. :param name: Name of the collection.
  225. :type name: str
  226. """
  227. if not isinstance(name, str):
  228. raise TypeError("Collection name must be a string")
  229. self.config.collection_name = name
  230. self._get_or_create_collection(self.config.collection_name)
  231. def count(self) -> int:
  232. """
  233. Count number of documents/chunks embedded in the database.
  234. :return: number of documents
  235. :rtype: int
  236. """
  237. return self.collection.count_rows()
  238. def delete(self, where):
  239. return self.collection.delete(where=where)
  240. def reset(self):
  241. """
  242. Resets the database. Deletes all embeddings irreversibly.
  243. """
  244. # Delete all data from the collection and recreate collection
  245. if self.config.allow_reset:
  246. try:
  247. self._get_or_create_collection(self.config.collection_name, reset=True)
  248. except ValueError:
  249. raise ValueError(
  250. "For safety reasons, resetting is disabled. "
  251. "Please enable it by setting `allow_reset=True` in your LanceDbConfig"
  252. ) from None
  253. # Recreate
  254. else:
  255. print(
  256. "For safety reasons, resetting is disabled. "
  257. "Please enable it by setting `allow_reset=True` in your LanceDbConfig"
  258. )