weaviate.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364
import copy
import os
from typing import Any, Optional, Union

try:
    import weaviate
except ImportError:
    raise ImportError(
        "Weaviate requires extra dependencies. Install with `pip install --upgrade 'embedchain[weaviate]'`"
    ) from None

from embedchain.config.vectordb.weaviate import WeaviateDBConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.vectordb.base import BaseVectorDB
  13. @register_deserializable
  14. class WeaviateDB(BaseVectorDB):
  15. """
  16. Weaviate as vector database
  17. """
  18. BATCH_SIZE = 100
  19. def __init__(
  20. self,
  21. config: Optional[WeaviateDBConfig] = None,
  22. ):
  23. """Weaviate as vector database.
  24. :param config: Weaviate database config, defaults to None
  25. :type config: WeaviateDBConfig, optional
  26. :raises ValueError: No config provided
  27. """
  28. if config is None:
  29. self.config = WeaviateDBConfig()
  30. else:
  31. if not isinstance(config, WeaviateDBConfig):
  32. raise TypeError(
  33. "config is not a `WeaviateDBConfig` instance. "
  34. "Please make sure the type is right and that you are passing an instance."
  35. )
  36. self.config = config
  37. self.client = weaviate.Client(
  38. url=os.environ.get("WEAVIATE_ENDPOINT"),
  39. auth_client_secret=weaviate.AuthApiKey(api_key=os.environ.get("WEAVIATE_API_KEY")),
  40. **self.config.extra_params,
  41. )
  42. # Since weaviate uses graphQL, we need to keep track of metadata keys added in the vectordb.
  43. # This is needed to filter data while querying.
  44. self.metadata_keys = {"data_type", "doc_id", "url", "hash", "app_id"}
  45. # Call parent init here because embedder is needed
  46. super().__init__(config=self.config)
  47. def _initialize(self):
  48. """
  49. This method is needed because `embedder` attribute needs to be set externally before it can be initialized.
  50. """
  51. if not self.embedder:
  52. raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
  53. self.index_name = self._get_index_name()
  54. if not self.client.schema.exists(self.index_name):
  55. # id is a reserved field in Weaviate, hence we had to change the name of the id field to identifier
  56. # The none vectorizer is crucial as we have our own custom embedding function
  57. """
  58. TODO: wait for weaviate to add indexing on `object[]` data-type so that we can add filter while querying.
  59. Once that is done, change `dataType` of "metadata" field to `object[]` and update the query below.
  60. """
  61. class_obj = {
  62. "classes": [
  63. {
  64. "class": self.index_name,
  65. "vectorizer": "none",
  66. "properties": [
  67. {
  68. "name": "identifier",
  69. "dataType": ["text"],
  70. },
  71. {
  72. "name": "text",
  73. "dataType": ["text"],
  74. },
  75. {
  76. "name": "metadata",
  77. "dataType": [self.index_name + "_metadata"],
  78. },
  79. ],
  80. },
  81. {
  82. "class": self.index_name + "_metadata",
  83. "vectorizer": "none",
  84. "properties": [
  85. {
  86. "name": "data_type",
  87. "dataType": ["text"],
  88. },
  89. {
  90. "name": "doc_id",
  91. "dataType": ["text"],
  92. },
  93. {
  94. "name": "url",
  95. "dataType": ["text"],
  96. },
  97. {
  98. "name": "hash",
  99. "dataType": ["text"],
  100. },
  101. {
  102. "name": "app_id",
  103. "dataType": ["text"],
  104. },
  105. ],
  106. },
  107. ]
  108. }
  109. self.client.schema.create(class_obj)
  110. def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
  111. """
  112. Get existing doc ids present in vector database
  113. :param ids: _list of doc ids to check for existance
  114. :type ids: list[str]
  115. :param where: to filter data
  116. :type where: dict[str, any]
  117. :return: ids
  118. :rtype: Set[str]
  119. """
  120. weaviate_where_operands = []
  121. if ids:
  122. for doc_id in ids:
  123. weaviate_where_operands.append({"path": ["identifier"], "operator": "Equal", "valueText": doc_id})
  124. keys = set(where.keys() if where is not None else set())
  125. if len(keys) > 0:
  126. for key in keys:
  127. weaviate_where_operands.append(
  128. {
  129. "path": ["metadata", self.index_name + "_metadata", key],
  130. "operator": "Equal",
  131. "valueText": where.get(key),
  132. }
  133. )
  134. if len(weaviate_where_operands) == 1:
  135. weaviate_where_clause = weaviate_where_operands[0]
  136. else:
  137. weaviate_where_clause = {"operator": "And", "operands": weaviate_where_operands}
  138. existing_ids = []
  139. metadatas = []
  140. cursor = None
  141. offset = 0
  142. has_iterated_once = False
  143. query_metadata_keys = self.metadata_keys.union(keys)
  144. while cursor is not None or not has_iterated_once:
  145. has_iterated_once = True
  146. results = self._query_with_offset(
  147. self.client.query.get(
  148. self.index_name,
  149. [
  150. "identifier",
  151. weaviate.LinkTo("metadata", self.index_name + "_metadata", list(query_metadata_keys)),
  152. ],
  153. )
  154. .with_where(weaviate_where_clause)
  155. .with_additional(["id"])
  156. .with_limit(limit or self.BATCH_SIZE),
  157. offset,
  158. )
  159. fetched_results = results["data"]["Get"].get(self.index_name, [])
  160. if not fetched_results:
  161. break
  162. for result in fetched_results:
  163. existing_ids.append(result["identifier"])
  164. metadatas.append(result["metadata"][0])
  165. cursor = result["_additional"]["id"]
  166. offset += 1
  167. if limit is not None and len(existing_ids) >= limit:
  168. break
  169. return {"ids": existing_ids, "metadatas": metadatas}
  170. def add(self, documents: list[str], metadatas: list[object], ids: list[str], **kwargs: Optional[dict[str, any]]):
  171. """add data in vector database
  172. :param documents: list of texts to add
  173. :type documents: list[str]
  174. :param metadatas: list of metadata associated with docs
  175. :type metadatas: list[object]
  176. :param ids: ids of docs
  177. :type ids: list[str]
  178. """
  179. embeddings = self.embedder.embedding_fn(documents)
  180. self.client.batch.configure(batch_size=self.BATCH_SIZE, timeout_retries=3) # Configure batch
  181. with self.client.batch as batch: # Initialize a batch process
  182. for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
  183. doc = {"identifier": id, "text": text}
  184. updated_metadata = {"text": text}
  185. if metadata is not None:
  186. updated_metadata.update(**metadata)
  187. obj_uuid = batch.add_data_object(
  188. data_object=copy.deepcopy(doc), class_name=self.index_name, vector=embedding
  189. )
  190. metadata_uuid = batch.add_data_object(
  191. data_object=copy.deepcopy(updated_metadata),
  192. class_name=self.index_name + "_metadata",
  193. vector=embedding,
  194. )
  195. batch.add_reference(
  196. obj_uuid, self.index_name, "metadata", metadata_uuid, self.index_name + "_metadata", **kwargs
  197. )
  198. def query(
  199. self, input_query: list[str], n_results: int, where: dict[str, any], citations: bool = False
  200. ) -> Union[list[tuple[str, dict]], list[str]]:
  201. """
  202. query contents from vector database based on vector similarity
  203. :param input_query: list of query string
  204. :type input_query: list[str]
  205. :param n_results: no of similar documents to fetch from database
  206. :type n_results: int
  207. :param where: Optional. to filter data
  208. :type where: dict[str, any]
  209. :param citations: we use citations boolean param to return context along with the answer.
  210. :type citations: bool, default is False.
  211. :return: The content of the document that matched your query,
  212. along with url of the source and doc_id (if citations flag is true)
  213. :rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
  214. """
  215. query_vector = self.embedder.embedding_fn([input_query])[0]
  216. keys = set(where.keys() if where is not None else set())
  217. data_fields = ["text"]
  218. query_metadata_keys = self.metadata_keys.union(keys)
  219. if citations:
  220. data_fields.append(weaviate.LinkTo("metadata", self.index_name + "_metadata", list(query_metadata_keys)))
  221. if len(keys) > 0:
  222. weaviate_where_operands = []
  223. for key in keys:
  224. weaviate_where_operands.append(
  225. {
  226. "path": ["metadata", self.index_name + "_metadata", key],
  227. "operator": "Equal",
  228. "valueText": where.get(key),
  229. }
  230. )
  231. if len(weaviate_where_operands) == 1:
  232. weaviate_where_clause = weaviate_where_operands[0]
  233. else:
  234. weaviate_where_clause = {"operator": "And", "operands": weaviate_where_operands}
  235. results = (
  236. self.client.query.get(self.index_name, data_fields)
  237. .with_where(weaviate_where_clause)
  238. .with_near_vector({"vector": query_vector})
  239. .with_limit(n_results)
  240. .with_additional(["distance"])
  241. .do()
  242. )
  243. else:
  244. results = (
  245. self.client.query.get(self.index_name, data_fields)
  246. .with_near_vector({"vector": query_vector})
  247. .with_limit(n_results)
  248. .with_additional(["distance"])
  249. .do()
  250. )
  251. if results["data"]["Get"].get(self.index_name) is None:
  252. return []
  253. docs = results["data"]["Get"].get(self.index_name)
  254. contexts = []
  255. for doc in docs:
  256. context = doc["text"]
  257. if citations:
  258. metadata = doc["metadata"][0]
  259. score = doc["_additional"]["distance"]
  260. metadata["score"] = score
  261. contexts.append((context, metadata))
  262. else:
  263. contexts.append(context)
  264. return contexts
  265. def set_collection_name(self, name: str):
  266. """
  267. Set the name of the collection. A collection is an isolated space for vectors.
  268. :param name: Name of the collection.
  269. :type name: str
  270. """
  271. if not isinstance(name, str):
  272. raise TypeError("Collection name must be a string")
  273. self.config.collection_name = name
  274. def count(self) -> int:
  275. """
  276. Count number of documents/chunks embedded in the database.
  277. :return: number of documents
  278. :rtype: int
  279. """
  280. data = self.client.query.aggregate(self.index_name).with_meta_count().do()
  281. return data["data"]["Aggregate"].get(self.index_name)[0]["meta"]["count"]
  282. def _get_or_create_db(self):
  283. """Called during initialization"""
  284. return self.client
  285. def reset(self):
  286. """
  287. Resets the database. Deletes all embeddings irreversibly.
  288. """
  289. # Delete all data from the database
  290. self.client.batch.delete_objects(
  291. self.index_name, where={"path": ["identifier"], "operator": "Like", "valueText": ".*"}
  292. )
  293. # Weaviate internally by default capitalizes the class name
  294. def _get_index_name(self) -> str:
  295. """Get the Weaviate index for a collection
  296. :return: Weaviate index
  297. :rtype: str
  298. """
  299. return f"{self.config.collection_name}_{self.embedder.vector_dimension}".capitalize().replace("-", "_")
  300. @staticmethod
  301. def _query_with_offset(query, offset):
  302. if offset:
  303. query.with_offset(offset)
  304. results = query.do()
  305. return results
  306. def _generate_query(self, where: dict):
  307. weaviate_where_operands = []
  308. for key, value in where.items():
  309. weaviate_where_operands.append(
  310. {
  311. "path": ["metadata", self.index_name + "_metadata", key],
  312. "operator": "Equal",
  313. "valueText": value,
  314. }
  315. )
  316. if len(weaviate_where_operands) == 1:
  317. weaviate_where_clause = weaviate_where_operands[0]
  318. else:
  319. weaviate_where_clause = {"operator": "And", "operands": weaviate_where_operands}
  320. return weaviate_where_clause
  321. def delete(self, where: dict):
  322. """Delete from database.
  323. :param where: to filter data
  324. :type where: dict[str, any]
  325. """
  326. query = self._generate_query(where)
  327. self.client.batch.delete_objects(self.index_name, where=query)