lancedb.py 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. from typing import Any, Dict, List, Optional, Union
  2. import pyarrow as pa
  3. try:
  4. import lancedb
  5. except ImportError:
  6. raise ImportError('LanceDB is required. Install with pip install "embedchain[lancedb]"') from None
  7. from embedchain.config.vectordb.lancedb import LanceDBConfig
  8. from embedchain.helpers.json_serializable import register_deserializable
  9. from embedchain.vectordb.base import BaseVectorDB
  @register_deserializable
  class LanceDB(BaseVectorDB):
      """
      LanceDB as vector database.

      Each row of the backing table holds the document text (``doc``), its
      metadata serialized with ``str()`` (``metadata``), its id (``id``) and,
      when the embedder passed the probe in ``_initialize``, its embedding
      (``vector``).
      """

      # Number of documents sent to ``embedding_fn`` per call in ``add``.
      BATCH_SIZE = 100

      def __init__(
          self,
          config: Optional[LanceDBConfig] = None,
      ):
          """LanceDB as vector database.

          :param config: LanceDB database config, defaults to None
          :type config: LanceDBConfig, optional
          """
          self.config = config or LanceDBConfig()
          self.client = lancedb.connect(self.config.dir or "~/.lancedb")
          # Optimistically assume the embedder works; `_initialize` clears this
          # flag if a trial embedding fails.
          self.embedder_check = True
          super().__init__(config=self.config)

      def _initialize(self):
          """
          Probe the embedder and open (or create) the configured table.

          This method is needed because the `embedder` attribute needs to be set
          externally before it can be initialized.

          :raises ValueError: if no embedder has been set.
          """
          if not self.embedder:
              raise ValueError(
                  "Embedder not set. Please set an embedder with `_set_embedder()` function before initialization."
              )
          # Deliberate best-effort probe: if embedding fails, fall back to a table
          # schema without a vector column instead of failing initialization.
          try:
              # Same list-in / list-out calling convention used by `add` and `query`.
              self.embedder.embedding_fn(["Hello LanceDB"])
          except Exception:
              self.embedder_check = False
          self._get_or_create_collection(self.config.collection_name)

      def _get_or_create_db(self):
          """
          Return the LanceDB connection opened in ``__init__``.

          Called during initialization.
          """
          return self.client

      @staticmethod
      def _sql_literal(value: Any) -> str:
          """Render a value as a SQL literal: strings single-quoted with embedded quotes doubled."""
          if isinstance(value, str):
              return "'" + value.replace("'", "''") + "'"
          return str(value)

      @classmethod
      def _id_in_clause(cls, ids: List[str]) -> str:
          """Build an ``id IN (...)`` predicate for a NON-EMPTY list of ids."""
          literals = ", ".join(cls._sql_literal(str(doc_id)) for doc_id in ids)
          return f"id IN ({literals})"

      @staticmethod
      def _parse_metadata(raw: Optional[str]) -> Optional[Dict[str, Any]]:
          """Turn a metadata string written by ``add`` (``str(dict)``) back into a dict, or None."""
          import ast  # local import: only needed for metadata filtering

          if not raw:
              return None
          try:
              # literal_eval only accepts Python literals, so it is safe on stored data.
              parsed = ast.literal_eval(raw)
          except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
              # e.g. metadata containing objects whose repr is not a Python literal
              return None
          return parsed if isinstance(parsed, dict) else None

      @classmethod
      def _metadata_matches(cls, raw: Optional[str], where: Dict[str, Any]) -> bool:
          """Return True when every ``where`` key is present in the stored metadata with an equal value."""
          metadata = cls._parse_metadata(raw)
          if metadata is None:
              return False
          return all(key in metadata and metadata[key] == value for key, value in where.items())

      def _generate_where_clause(self, where: Dict[str, Any]) -> str:
          """
          Generate a SQL ``WHERE`` predicate from attribute/value pairs.

          Pairs are joined with ``AND``. String values are single-quoted (embedded
          quotes doubled) so they compare as literals instead of being parsed as
          column names. An empty dict yields an empty string.

          NOTE(review): keys must name real table columns (``doc``, ``metadata``,
          ``id`` and, with a working embedder, ``vector``); individual metadata
          fields are not separate columns in this table's schema.
          """
          return " AND ".join(f"{key} = {self._sql_literal(value)}" for key, value in where.items())

      def _get_or_create_collection(self, table_name: str, reset=False):
          """
          Get or create a named collection (LanceDB table).

          :param table_name: Name of the table
          :type table_name: str
          :param reset: drop any existing table with this name and recreate it empty
          :type reset: bool
          :return: The opened or newly created table, also stored on ``self.collection``
          """
          exists = table_name in self.client.table_names()
          if reset and exists:
              self.client.drop_table(table_name)
              exists = False

          if exists:
              self.collection = self.client[table_name]
              return self.collection

          fields = [
              pa.field("doc", pa.string()),
              pa.field("metadata", pa.string()),
              pa.field("id", pa.string()),
          ]
          if self.embedder_check:
              # Fixed-size vector column sized to the embedder's output dimension.
              fields.insert(0, pa.field("vector", pa.list_(pa.float32(), list_size=self.embedder.vector_dimension)))
          self.collection = self.client.create_table(table_name, schema=pa.schema(fields))
          return self.collection

      def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, Any]] = None, limit: Optional[int] = None):
          """
          Get existing doc ids present in vector database

          :param ids: list of doc ids to check for existence; None means any id
          :type ids: List[str]
          :param where: Optional. Metadata key/value pairs that must all match
          :type where: Dict[str, Any]
          :param limit: Optional. maximum number of documents; None means no limit
          :type limit: Optional[int]
          :return: Existing documents as ``{"ids": [...], "metadatas": [...]}``,
              where each metadata entry is the string stored by ``add``.
          :rtype: Dict[str, List[str]]
          """
          results = {"ids": [], "metadatas": []}

          id_filter = None
          if ids is not None:
              ids = list(ids)
              if not ids:
                  # `id IN ()` is invalid SQL, and nothing could match anyway.
                  return results
              id_filter = self._id_in_clause(ids)

          table = (
              self.collection.to_lance()
              .scanner(filter=id_filter, columns=["id", "metadata"])
              .to_table()
              .to_pydict()
          )
          for doc_id, raw_metadata in zip(table["id"], table["metadata"]):
              if limit is not None and len(results["ids"]) >= limit:
                  break
              # Metadata lives in a single serialized string column, so `where` is
              # evaluated here in Python rather than pushed into the SQL filter.
              if where and not self._metadata_matches(raw_metadata, where):
                  continue
              results["ids"].append(doc_id)
              results["metadatas"].append(raw_metadata)
          return results

      def add(
          self,
          documents: List[str],
          metadatas: List[object],
          ids: List[str],
      ) -> Any:
          """
          Add vectors to lancedb database

          Documents are embedded in batches of ``BATCH_SIZE`` (only when the
          embedder passed the probe in ``_initialize``). Metadata is stored as
          ``str(metadata)``. The three lists are paired with ``zip``, so items
          beyond the shortest list are ignored.

          :param documents: Documents
          :type documents: List[str]
          :param metadatas: Metadatas
          :type metadatas: List[object]
          :param ids: ids
          :type ids: List[str]
          :raises ValueError: if the embedder returns a different number of vectors
              than the documents it was given.
          """
          rows = list(zip(documents, metadatas, ids))
          if not rows:
              # Nothing to insert; avoid handing LanceDB an empty batch.
              return None

          if not self.embedder_check:
              data = [{"doc": doc, "metadata": str(meta), "id": doc_id} for doc, meta, doc_id in rows]
          else:
              data = []
              for start in range(0, len(rows), self.BATCH_SIZE):
                  batch = rows[start : start + self.BATCH_SIZE]
                  vectors = self.embedder.embedding_fn([doc for doc, _, _ in batch])
                  if len(vectors) != len(batch):
                      raise ValueError(
                          f"Embedder returned {len(vectors)} vectors for {len(batch)} documents."
                      )
                  for (doc, meta, doc_id), vector in zip(batch, vectors):
                      data.append({"doc": doc, "vector": vector, "metadata": str(meta), "id": doc_id})
          self.collection.add(data=data)

      def _format_result(self, results) -> list:
          """
          Format LanceDB results

          :param results: LanceDB query results to format; assumed to expose
              ``tolist()`` (e.g. a NumPy/pandas container) -- TODO confirm.
          :return: ``results.tolist()``
          :rtype: list
          """
          return results.tolist()

      def query(
          self,
          input_query: str,
          n_results: int = 3,
          where: Optional[Dict[str, Any]] = None,
          raw_filter: Optional[Dict[str, Any]] = None,
          citations: bool = False,
          **kwargs: Optional[Dict[str, Any]],
      ) -> Union[List[tuple], List[str]]:
          """
          Query contents from vector database based on vector similarity

          :param input_query: query string
          :type input_query: str
          :param n_results: no of similar documents to fetch from database
          :type n_results: int
          :param where: to filter data (validated but currently not applied, see note)
          :type where: dict[str, Any]
          :param raw_filter: Raw filter to apply (validated but currently not applied)
          :type raw_filter: dict[str, Any]
          :param citations: we use citations boolean param to return context along with the answer.
          :type citations: bool, default is False.
          :raises ValueError: if both ``where`` and ``raw_filter`` are given.
          :return: The matched document texts or, if ``citations`` is true,
              ``(doc, metadata)`` tuples where metadata is the string stored by ``add``.
          :rtype: list[str], if citations=False, otherwise list[tuple[str, str]]
          """
          if where and raw_filter:
              raise ValueError("Both `where` and `raw_filter` cannot be used together.")

          # NOTE(review): `where` / `raw_filter` are not applied. Metadata is one
          # serialized string column, so a SQL filter on metadata keys would name
          # columns that do not exist in this table's schema.

          # Same list-in / list-out convention as `add`: embed one query, take row 0.
          query_embedding = self.embedder.embedding_fn([input_query])[0]
          rows = self.collection.search(query_embedding).limit(n_results).to_list()

          if citations:
              return [(row["doc"], row["metadata"]) for row in rows]
          return [row["doc"] for row in rows]

      def set_collection_name(self, name: str):
          """
          Set the name of the collection. A collection is an isolated space for vectors.

          :param name: Name of the collection.
          :type name: str
          :raises TypeError: if ``name`` is not a string.
          """
          if not isinstance(name, str):
              raise TypeError("Collection name must be a string")
          self.config.collection_name = name
          self._get_or_create_collection(self.config.collection_name)

      def count(self) -> int:
          """
          Count number of documents/chunks embedded in the database.

          :return: number of documents
          :rtype: int
          """
          return self.collection.count_rows()

      def delete(self, where):
          """
          Delete rows from the table.

          :param where: either a SQL predicate string passed straight to LanceDB,
              or a dict of metadata key/value pairs; for a dict, matching row ids
              are looked up with ``get`` and deleted by id.
          :type where: Union[str, Dict[str, Any]]
          :raises ValueError: if ``where`` is an empty dict (it would match every row).
          """
          if isinstance(where, dict):
              if not where:
                  raise ValueError("`where` must not be empty when deleting by metadata.")
              matched_ids = self.get(where=where)["ids"]
              if not matched_ids:
                  return None
              where = self._id_in_clause(matched_ids)
          return self.collection.delete(where=where)

      def reset(self):
          """
          Resets the database. Deletes all embeddings irreversibly.

          Only acts when ``config.allow_reset`` is true; otherwise prints a notice
          and leaves the data untouched.
          """
          if not self.config.allow_reset:
              print(
                  "For safety reasons, resetting is disabled. "
                  "Please enable it by setting `allow_reset=True` in your LanceDbConfig"
              )
              return
          # Drop the table (if present) and recreate it empty.
          self._get_or_create_collection(self.config.collection_name, reset=True)