weaviate.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
  1. import copy
  2. import os
  3. from typing import Dict, List, Optional, Tuple, Union
  4. try:
  5. import weaviate
  6. except ImportError:
  7. raise ImportError(
  8. "Weaviate requires extra dependencies. Install with `pip install --upgrade 'embedchain[weaviate]'`"
  9. ) from None
  10. from embedchain.config.vectordb.weaviate import WeaviateDBConfig
  11. from embedchain.helper.json_serializable import register_deserializable
  12. from embedchain.vectordb.base import BaseVectorDB
  13. @register_deserializable
  14. class WeaviateDB(BaseVectorDB):
  15. """
  16. Weaviate as vector database
  17. """
  18. BATCH_SIZE = 100
  19. def __init__(
  20. self,
  21. config: Optional[WeaviateDBConfig] = None,
  22. ):
  23. """Weaviate as vector database.
  24. :param config: Weaviate database config, defaults to None
  25. :type config: WeaviateDBConfig, optional
  26. :raises ValueError: No config provided
  27. """
  28. if config is None:
  29. self.config = WeaviateDBConfig()
  30. else:
  31. if not isinstance(config, WeaviateDBConfig):
  32. raise TypeError(
  33. "config is not a `WeaviateDBConfig` instance. "
  34. "Please make sure the type is right and that you are passing an instance."
  35. )
  36. self.config = config
  37. self.client = weaviate.Client(
  38. url=os.environ.get("WEAVIATE_ENDPOINT"),
  39. auth_client_secret=weaviate.AuthApiKey(api_key=os.environ.get("WEAVIATE_API_KEY")),
  40. **self.config.extra_params,
  41. )
  42. # Call parent init here because embedder is needed
  43. super().__init__(config=self.config)
  44. def _initialize(self):
  45. """
  46. This method is needed because `embedder` attribute needs to be set externally before it can be initialized.
  47. """
  48. if not self.embedder:
  49. raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
  50. self.index_name = self._get_index_name()
  51. self.metadata_keys = {"data_type", "doc_id", "url", "hash", "app_id"}
  52. if not self.client.schema.exists(self.index_name):
  53. # id is a reserved field in Weaviate, hence we had to change the name of the id field to identifier
  54. # The none vectorizer is crucial as we have our own custom embedding function
  55. """
  56. TODO: wait for weaviate to add indexing on `object[]` data-type so that we can add filter while querying.
  57. Once that is done, change `dataType` of "metadata" field to `object[]` and update the query below.
  58. """
  59. class_obj = {
  60. "classes": [
  61. {
  62. "class": self.index_name,
  63. "vectorizer": "none",
  64. "properties": [
  65. {
  66. "name": "identifier",
  67. "dataType": ["text"],
  68. },
  69. {
  70. "name": "text",
  71. "dataType": ["text"],
  72. },
  73. {
  74. "name": "metadata",
  75. "dataType": [self.index_name + "_metadata"],
  76. },
  77. ],
  78. },
  79. {
  80. "class": self.index_name + "_metadata",
  81. "vectorizer": "none",
  82. "properties": [
  83. {
  84. "name": "data_type",
  85. "dataType": ["text"],
  86. },
  87. {
  88. "name": "doc_id",
  89. "dataType": ["text"],
  90. },
  91. {
  92. "name": "url",
  93. "dataType": ["text"],
  94. },
  95. {
  96. "name": "hash",
  97. "dataType": ["text"],
  98. },
  99. {
  100. "name": "app_id",
  101. "dataType": ["text"],
  102. },
  103. ],
  104. },
  105. ]
  106. }
  107. self.client.schema.create(class_obj)
  108. def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
  109. """
  110. Get existing doc ids present in vector database
  111. :param ids: _list of doc ids to check for existance
  112. :type ids: List[str]
  113. :param where: to filter data
  114. :type where: Dict[str, any]
  115. :return: ids
  116. :rtype: Set[str]
  117. """
  118. if ids is None or len(ids) == 0:
  119. return {"ids": []}
  120. existing_ids = []
  121. cursor = None
  122. has_iterated_once = False
  123. while cursor is not None or not has_iterated_once:
  124. has_iterated_once = True
  125. results = self._query_with_cursor(
  126. self.client.query.get(self.index_name, ["identifier"])
  127. .with_additional(["id"])
  128. .with_limit(self.BATCH_SIZE),
  129. cursor,
  130. )
  131. fetched_results = results["data"]["Get"].get(self.index_name, [])
  132. if len(fetched_results) == 0:
  133. break
  134. for result in fetched_results:
  135. existing_ids.append(result["identifier"])
  136. cursor = result["_additional"]["id"]
  137. return {"ids": existing_ids}
  138. def add(
  139. self,
  140. embeddings: List[List[float]],
  141. documents: List[str],
  142. metadatas: List[object],
  143. ids: List[str],
  144. skip_embedding: bool,
  145. ):
  146. """add data in vector database
  147. :param embeddings: list of embeddings for the corresponding documents to be added
  148. :type documents: List[List[float]]
  149. :param documents: list of texts to add
  150. :type documents: List[str]
  151. :param metadatas: list of metadata associated with docs
  152. :type metadatas: List[object]
  153. :param ids: ids of docs
  154. :type ids: List[str]
  155. :param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
  156. generated or not
  157. :type skip_embedding: bool
  158. """
  159. print("Adding documents to Weaviate...")
  160. if not skip_embedding:
  161. embeddings = self.embedder.embedding_fn(documents)
  162. self.client.batch.configure(batch_size=self.BATCH_SIZE, timeout_retries=3) # Configure batch
  163. with self.client.batch as batch: # Initialize a batch process
  164. for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
  165. doc = {"identifier": id, "text": text}
  166. updated_metadata = {"text": text}
  167. if metadata is not None:
  168. updated_metadata.update(**metadata)
  169. obj_uuid = batch.add_data_object(
  170. data_object=copy.deepcopy(doc), class_name=self.index_name, vector=embedding
  171. )
  172. metadata_uuid = batch.add_data_object(
  173. data_object=copy.deepcopy(updated_metadata),
  174. class_name=self.index_name + "_metadata",
  175. vector=embedding,
  176. )
  177. batch.add_reference(obj_uuid, self.index_name, "metadata", metadata_uuid, self.index_name + "_metadata")
  178. def query(
  179. self,
  180. input_query: List[str],
  181. n_results: int,
  182. where: Dict[str, any],
  183. skip_embedding: bool,
  184. citations: bool = False,
  185. ) -> Union[List[Tuple[str, str, str]], List[str]]:
  186. """
  187. query contents from vector database based on vector similarity
  188. :param input_query: list of query string
  189. :type input_query: List[str]
  190. :param n_results: no of similar documents to fetch from database
  191. :type n_results: int
  192. :param where: Optional. to filter data
  193. :type where: Dict[str, any]
  194. :param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
  195. generated or not
  196. :type skip_embedding: bool
  197. :param citations: we use citations boolean param to return context along with the answer.
  198. :type citations: bool, default is False.
  199. :return: The content of the document that matched your query,
  200. along with url of the source and doc_id (if citations flag is true)
  201. :rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
  202. """
  203. if not skip_embedding:
  204. query_vector = self.embedder.embedding_fn([input_query])[0]
  205. else:
  206. query_vector = input_query
  207. keys = set(where.keys() if where is not None else set())
  208. data_fields = ["text"]
  209. if citations:
  210. data_fields.append(weaviate.LinkTo("metadata", self.index_name + "_metadata", list(self.metadata_keys)))
  211. if len(keys.intersection(self.metadata_keys)) != 0:
  212. weaviate_where_operands = []
  213. for key in keys:
  214. if key in self.metadata_keys:
  215. weaviate_where_operands.append(
  216. {
  217. "path": ["metadata", self.index_name + "_metadata", key],
  218. "operator": "Equal",
  219. "valueText": where.get(key),
  220. }
  221. )
  222. if len(weaviate_where_operands) == 1:
  223. weaviate_where_clause = weaviate_where_operands[0]
  224. else:
  225. weaviate_where_clause = {"operator": "And", "operands": weaviate_where_operands}
  226. results = (
  227. self.client.query.get(self.index_name, data_fields)
  228. .with_where(weaviate_where_clause)
  229. .with_near_vector({"vector": query_vector})
  230. .with_limit(n_results)
  231. .do()
  232. )
  233. else:
  234. results = (
  235. self.client.query.get(self.index_name, data_fields)
  236. .with_near_vector({"vector": query_vector})
  237. .with_limit(n_results)
  238. .do()
  239. )
  240. docs = results["data"]["Get"].get(self.index_name)
  241. contexts = []
  242. for doc in docs:
  243. context = doc["text"]
  244. if citations:
  245. metadata = doc["metadata"][0]
  246. source = metadata["url"]
  247. doc_id = metadata["doc_id"]
  248. contexts.append((context, source, doc_id))
  249. else:
  250. contexts.append(context)
  251. return contexts
  252. def set_collection_name(self, name: str):
  253. """
  254. Set the name of the collection. A collection is an isolated space for vectors.
  255. :param name: Name of the collection.
  256. :type name: str
  257. """
  258. if not isinstance(name, str):
  259. raise TypeError("Collection name must be a string")
  260. self.config.collection_name = name
  261. def count(self) -> int:
  262. """
  263. Count number of documents/chunks embedded in the database.
  264. :return: number of documents
  265. :rtype: int
  266. """
  267. data = self.client.query.aggregate(self.index_name).with_meta_count().do()
  268. return data["data"]["Aggregate"].get(self.index_name)[0]["meta"]["count"]
  269. def _get_or_create_db(self):
  270. """Called during initialization"""
  271. return self.client
  272. def reset(self):
  273. """
  274. Resets the database. Deletes all embeddings irreversibly.
  275. """
  276. # Delete all data from the database
  277. self.client.batch.delete_objects(
  278. self.index_name, where={"path": ["identifier"], "operator": "Like", "valueText": ".*"}
  279. )
  280. # Weaviate internally by default capitalizes the class name
  281. def _get_index_name(self) -> str:
  282. """Get the Weaviate index for a collection
  283. :return: Weaviate index
  284. :rtype: str
  285. """
  286. return f"{self.config.collection_name}_{self.embedder.vector_dimension}".capitalize()
  287. def _query_with_cursor(self, query, cursor):
  288. if cursor is not None:
  289. query.with_after(cursor)
  290. results = query.do()
  291. return results