weaviate.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362
import copy
import os
from typing import Any, Optional, Union

try:
    import weaviate
except ImportError:
    raise ImportError(
        "Weaviate requires extra dependencies. Install with `pip install --upgrade 'embedchain[weaviate]'`"
    ) from None

from embedchain.config.vectordb.weaviate import WeaviateDBConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.vectordb.base import BaseVectorDB
  13. @register_deserializable
  14. class WeaviateDB(BaseVectorDB):
  15. """
  16. Weaviate as vector database
  17. """
  18. def __init__(
  19. self,
  20. config: Optional[WeaviateDBConfig] = None,
  21. ):
  22. """Weaviate as vector database.
  23. :param config: Weaviate database config, defaults to None
  24. :type config: WeaviateDBConfig, optional
  25. :raises ValueError: No config provided
  26. """
  27. if config is None:
  28. self.config = WeaviateDBConfig()
  29. else:
  30. if not isinstance(config, WeaviateDBConfig):
  31. raise TypeError(
  32. "config is not a `WeaviateDBConfig` instance. "
  33. "Please make sure the type is right and that you are passing an instance."
  34. )
  35. self.config = config
  36. self.client = weaviate.Client(
  37. url=os.environ.get("WEAVIATE_ENDPOINT"),
  38. auth_client_secret=weaviate.AuthApiKey(api_key=os.environ.get("WEAVIATE_API_KEY")),
  39. **self.config.extra_params,
  40. )
  41. # Since weaviate uses graphQL, we need to keep track of metadata keys added in the vectordb.
  42. # This is needed to filter data while querying.
  43. self.metadata_keys = {"data_type", "doc_id", "url", "hash", "app_id"}
  44. # Call parent init here because embedder is needed
  45. super().__init__(config=self.config)
  46. def _initialize(self):
  47. """
  48. This method is needed because `embedder` attribute needs to be set externally before it can be initialized.
  49. """
  50. if not self.embedder:
  51. raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
  52. self.index_name = self._get_index_name()
  53. if not self.client.schema.exists(self.index_name):
  54. # id is a reserved field in Weaviate, hence we had to change the name of the id field to identifier
  55. # The none vectorizer is crucial as we have our own custom embedding function
  56. """
  57. TODO: wait for weaviate to add indexing on `object[]` data-type so that we can add filter while querying.
  58. Once that is done, change `dataType` of "metadata" field to `object[]` and update the query below.
  59. """
  60. class_obj = {
  61. "classes": [
  62. {
  63. "class": self.index_name,
  64. "vectorizer": "none",
  65. "properties": [
  66. {
  67. "name": "identifier",
  68. "dataType": ["text"],
  69. },
  70. {
  71. "name": "text",
  72. "dataType": ["text"],
  73. },
  74. {
  75. "name": "metadata",
  76. "dataType": [self.index_name + "_metadata"],
  77. },
  78. ],
  79. },
  80. {
  81. "class": self.index_name + "_metadata",
  82. "vectorizer": "none",
  83. "properties": [
  84. {
  85. "name": "data_type",
  86. "dataType": ["text"],
  87. },
  88. {
  89. "name": "doc_id",
  90. "dataType": ["text"],
  91. },
  92. {
  93. "name": "url",
  94. "dataType": ["text"],
  95. },
  96. {
  97. "name": "hash",
  98. "dataType": ["text"],
  99. },
  100. {
  101. "name": "app_id",
  102. "dataType": ["text"],
  103. },
  104. ],
  105. },
  106. ]
  107. }
  108. self.client.schema.create(class_obj)
  109. def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
  110. """
  111. Get existing doc ids present in vector database
  112. :param ids: _list of doc ids to check for existance
  113. :type ids: list[str]
  114. :param where: to filter data
  115. :type where: dict[str, any]
  116. :return: ids
  117. :rtype: Set[str]
  118. """
  119. weaviate_where_operands = []
  120. if ids:
  121. for doc_id in ids:
  122. weaviate_where_operands.append({"path": ["identifier"], "operator": "Equal", "valueText": doc_id})
  123. keys = set(where.keys() if where is not None else set())
  124. if len(keys) > 0:
  125. for key in keys:
  126. weaviate_where_operands.append(
  127. {
  128. "path": ["metadata", self.index_name + "_metadata", key],
  129. "operator": "Equal",
  130. "valueText": where.get(key),
  131. }
  132. )
  133. if len(weaviate_where_operands) == 1:
  134. weaviate_where_clause = weaviate_where_operands[0]
  135. else:
  136. weaviate_where_clause = {"operator": "And", "operands": weaviate_where_operands}
  137. existing_ids = []
  138. metadatas = []
  139. cursor = None
  140. offset = 0
  141. has_iterated_once = False
  142. query_metadata_keys = self.metadata_keys.union(keys)
  143. while cursor is not None or not has_iterated_once:
  144. has_iterated_once = True
  145. results = self._query_with_offset(
  146. self.client.query.get(
  147. self.index_name,
  148. [
  149. "identifier",
  150. weaviate.LinkTo("metadata", self.index_name + "_metadata", list(query_metadata_keys)),
  151. ],
  152. )
  153. .with_where(weaviate_where_clause)
  154. .with_additional(["id"])
  155. .with_limit(limit or self.config.batch_size),
  156. offset,
  157. )
  158. fetched_results = results["data"]["Get"].get(self.index_name, [])
  159. if not fetched_results:
  160. break
  161. for result in fetched_results:
  162. existing_ids.append(result["identifier"])
  163. metadatas.append(result["metadata"][0])
  164. cursor = result["_additional"]["id"]
  165. offset += 1
  166. if limit is not None and len(existing_ids) >= limit:
  167. break
  168. return {"ids": existing_ids, "metadatas": metadatas}
  169. def add(self, documents: list[str], metadatas: list[object], ids: list[str], **kwargs: Optional[dict[str, any]]):
  170. """add data in vector database
  171. :param documents: list of texts to add
  172. :type documents: list[str]
  173. :param metadatas: list of metadata associated with docs
  174. :type metadatas: list[object]
  175. :param ids: ids of docs
  176. :type ids: list[str]
  177. """
  178. embeddings = self.embedder.embedding_fn(documents)
  179. self.client.batch.configure(batch_size=self.config.batch_size, timeout_retries=3) # Configure batch
  180. with self.client.batch as batch: # Initialize a batch process
  181. for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
  182. doc = {"identifier": id, "text": text}
  183. updated_metadata = {"text": text}
  184. if metadata is not None:
  185. updated_metadata.update(**metadata)
  186. obj_uuid = batch.add_data_object(
  187. data_object=copy.deepcopy(doc), class_name=self.index_name, vector=embedding
  188. )
  189. metadata_uuid = batch.add_data_object(
  190. data_object=copy.deepcopy(updated_metadata),
  191. class_name=self.index_name + "_metadata",
  192. vector=embedding,
  193. )
  194. batch.add_reference(
  195. obj_uuid, self.index_name, "metadata", metadata_uuid, self.index_name + "_metadata", **kwargs
  196. )
  197. def query(
  198. self, input_query: str, n_results: int, where: dict[str, any], citations: bool = False
  199. ) -> Union[list[tuple[str, dict]], list[str]]:
  200. """
  201. query contents from vector database based on vector similarity
  202. :param input_query: query string
  203. :type input_query: str
  204. :param n_results: no of similar documents to fetch from database
  205. :type n_results: int
  206. :param where: Optional. to filter data
  207. :type where: dict[str, any]
  208. :param citations: we use citations boolean param to return context along with the answer.
  209. :type citations: bool, default is False.
  210. :return: The content of the document that matched your query,
  211. along with url of the source and doc_id (if citations flag is true)
  212. :rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
  213. """
  214. query_vector = self.embedder.embedding_fn([input_query])[0]
  215. keys = set(where.keys() if where is not None else set())
  216. data_fields = ["text"]
  217. query_metadata_keys = self.metadata_keys.union(keys)
  218. if citations:
  219. data_fields.append(weaviate.LinkTo("metadata", self.index_name + "_metadata", list(query_metadata_keys)))
  220. if len(keys) > 0:
  221. weaviate_where_operands = []
  222. for key in keys:
  223. weaviate_where_operands.append(
  224. {
  225. "path": ["metadata", self.index_name + "_metadata", key],
  226. "operator": "Equal",
  227. "valueText": where.get(key),
  228. }
  229. )
  230. if len(weaviate_where_operands) == 1:
  231. weaviate_where_clause = weaviate_where_operands[0]
  232. else:
  233. weaviate_where_clause = {"operator": "And", "operands": weaviate_where_operands}
  234. results = (
  235. self.client.query.get(self.index_name, data_fields)
  236. .with_where(weaviate_where_clause)
  237. .with_near_vector({"vector": query_vector})
  238. .with_limit(n_results)
  239. .with_additional(["distance"])
  240. .do()
  241. )
  242. else:
  243. results = (
  244. self.client.query.get(self.index_name, data_fields)
  245. .with_near_vector({"vector": query_vector})
  246. .with_limit(n_results)
  247. .with_additional(["distance"])
  248. .do()
  249. )
  250. if results["data"]["Get"].get(self.index_name) is None:
  251. return []
  252. docs = results["data"]["Get"].get(self.index_name)
  253. contexts = []
  254. for doc in docs:
  255. context = doc["text"]
  256. if citations:
  257. metadata = doc["metadata"][0]
  258. score = doc["_additional"]["distance"]
  259. metadata["score"] = score
  260. contexts.append((context, metadata))
  261. else:
  262. contexts.append(context)
  263. return contexts
  264. def set_collection_name(self, name: str):
  265. """
  266. Set the name of the collection. A collection is an isolated space for vectors.
  267. :param name: Name of the collection.
  268. :type name: str
  269. """
  270. if not isinstance(name, str):
  271. raise TypeError("Collection name must be a string")
  272. self.config.collection_name = name
  273. def count(self) -> int:
  274. """
  275. Count number of documents/chunks embedded in the database.
  276. :return: number of documents
  277. :rtype: int
  278. """
  279. data = self.client.query.aggregate(self.index_name).with_meta_count().do()
  280. return data["data"]["Aggregate"].get(self.index_name)[0]["meta"]["count"]
  281. def _get_or_create_db(self):
  282. """Called during initialization"""
  283. return self.client
  284. def reset(self):
  285. """
  286. Resets the database. Deletes all embeddings irreversibly.
  287. """
  288. # Delete all data from the database
  289. self.client.batch.delete_objects(
  290. self.index_name, where={"path": ["identifier"], "operator": "Like", "valueText": ".*"}
  291. )
  292. # Weaviate internally by default capitalizes the class name
  293. def _get_index_name(self) -> str:
  294. """Get the Weaviate index for a collection
  295. :return: Weaviate index
  296. :rtype: str
  297. """
  298. return f"{self.config.collection_name}_{self.embedder.vector_dimension}".capitalize().replace("-", "_")
  299. @staticmethod
  300. def _query_with_offset(query, offset):
  301. if offset:
  302. query.with_offset(offset)
  303. results = query.do()
  304. return results
  305. def _generate_query(self, where: dict):
  306. weaviate_where_operands = []
  307. for key, value in where.items():
  308. weaviate_where_operands.append(
  309. {
  310. "path": ["metadata", self.index_name + "_metadata", key],
  311. "operator": "Equal",
  312. "valueText": value,
  313. }
  314. )
  315. if len(weaviate_where_operands) == 1:
  316. weaviate_where_clause = weaviate_where_operands[0]
  317. else:
  318. weaviate_where_clause = {"operator": "And", "operands": weaviate_where_operands}
  319. return weaviate_where_clause
  320. def delete(self, where: dict):
  321. """Delete from database.
  322. :param where: to filter data
  323. :type where: dict[str, any]
  324. """
  325. query = self._generate_query(where)
  326. self.client.batch.delete_objects(self.index_name, where=query)