weaviate.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. import copy
  2. import os
  3. from typing import Optional, Union
  4. try:
  5. import weaviate
  6. except ImportError:
  7. raise ImportError(
  8. "Weaviate requires extra dependencies. Install with `pip install --upgrade 'embedchain[weaviate]'`"
  9. ) from None
  10. from embedchain.config.vectordb.weaviate import WeaviateDBConfig
  11. from embedchain.helpers.json_serializable import register_deserializable
  12. from embedchain.vectordb.base import BaseVectorDB
  @register_deserializable
  class WeaviateDB(BaseVectorDB):
      """
      Weaviate as vector database.

      Every chunk is stored as two Weaviate objects: one in the main class (``identifier``, ``text``
      and the embedding vector) and one in the companion ``<index>_metadata`` class, linked through
      the ``metadata`` cross-reference property of the main object.
      """

      BATCH_SIZE = 100

      def __init__(
          self,
          config: Optional[WeaviateDBConfig] = None,
      ):
          """Weaviate as vector database.

          The client connects to ``$WEAVIATE_ENDPOINT`` with the API key read from
          ``$WEAVIATE_API_KEY``. Entries in ``config.extra_params`` are forwarded to
          ``weaviate.Client`` and take precedence over those two defaults.

          :param config: Weaviate database config, defaults to None
          :type config: WeaviateDBConfig, optional
          :raises TypeError: ``config`` is not a ``WeaviateDBConfig`` instance
          """
          if config is None:
              self.config = WeaviateDBConfig()
          else:
              if not isinstance(config, WeaviateDBConfig):
                  raise TypeError(
                      "config is not a `WeaviateDBConfig` instance. "
                      "Please make sure the type is right and that you are passing an instance."
                  )
              self.config = config
          client_kwargs = {
              "url": os.environ.get("WEAVIATE_ENDPOINT"),
              "auth_client_secret": weaviate.AuthApiKey(api_key=os.environ.get("WEAVIATE_API_KEY")),
          }
          # Merge rather than `**`-splat next to explicit kwargs: a user-supplied `url` or
          # `auth_client_secret` now overrides the env defaults instead of raising a TypeError
          # ("got multiple values for keyword argument").
          client_kwargs.update(self.config.extra_params or {})
          self.client = weaviate.Client(**client_kwargs)
          # Call parent init here because embedder is needed
          super().__init__(config=self.config)

      def _initialize(self):
          """
          This method is needed because `embedder` attribute needs to be set externally before it can be initialized.

          Derives the index (Weaviate class) name and creates the schema if it does not exist yet.

          :raises ValueError: no embedder has been set
          """
          if not self.embedder:
              raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
          self.index_name = self._get_index_name()
          self.metadata_keys = {"data_type", "doc_id", "url", "hash", "app_id"}
          self._create_schema_if_missing()

      def _create_schema_if_missing(self):
          """Create the main class and its ``_metadata`` companion class unless the main class exists."""
          if self.client.schema.exists(self.index_name):
              return
          metadata_class = self.index_name + "_metadata"
          # `id` is a reserved field in Weaviate, hence the document id is stored as `identifier`.
          # The "none" vectorizer is crucial as embeddings come from our own embedding function.
          # TODO: wait for weaviate to add indexing on `object[]` data-type so that we can add filter
          # while querying. Once that is done, change `dataType` of "metadata" field to `object[]`
          # and update the filter built in `_build_where_clause`.
          class_obj = {
              "classes": [
                  {
                      "class": self.index_name,
                      "vectorizer": "none",
                      "properties": [
                          {"name": "identifier", "dataType": ["text"]},
                          {"name": "text", "dataType": ["text"]},
                          {"name": "metadata", "dataType": [metadata_class]},
                      ],
                  },
                  {
                      "class": metadata_class,
                      "vectorizer": "none",
                      "properties": [
                          {"name": name, "dataType": ["text"]}
                          for name in ("data_type", "doc_id", "url", "hash", "app_id")
                      ],
                  },
              ]
          }
          self.client.schema.create(class_obj)

      def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
          """
          Get the doc ids from ``ids`` that already exist in the vector database.

          Identifiers are listed page by page with Weaviate's cursor API and matched client-side
          against ``ids``; the scan stops as soon as every requested id has been found (or
          ``limit`` matches have been collected).

          :param ids: list of doc ids to check for existence
          :type ids: list[str]
          :param where: to filter data. NOTE: not applied by this store yet; existence is checked by id only.
          :type where: dict[str, any]
          :param limit: maximum number of ids to return, defaults to None (no limit)
          :type limit: int, optional
          :return: ``{"ids": [...]}`` holding the requested ids that exist; empty when ``ids`` is empty
          :rtype: dict[str, list[str]]
          :raises RuntimeError: Weaviate reports errors for the listing query
          """
          if not ids:
              return {"ids": []}
          pending = set(ids)
          existing_ids = []
          cursor = None
          while pending and (limit is None or len(existing_ids) < limit):
              results = self._query_with_cursor(
                  self.client.query.get(self.index_name, ["identifier"])
                  .with_additional(["id"])
                  .with_limit(self.BATCH_SIZE),
                  cursor,
              )
              page = results["data"]["Get"].get(self.index_name) or []
              for result in page:
                  identifier = result["identifier"]
                  if identifier in pending:
                      pending.discard(identifier)
                      existing_ids.append(identifier)
              if len(page) < self.BATCH_SIZE:
                  # A short (or empty) page means the cursor has reached the end of the class.
                  break
              cursor = page[-1]["_additional"]["id"]
          if limit is not None:
              existing_ids = existing_ids[:limit]
          return {"ids": existing_ids}

      def add(self, documents: list[str], metadatas: list[object], ids: list[str], **kwargs: Optional[dict[str, any]]):
          """add data in vector database

          Each document becomes a main-class object plus a linked ``_metadata`` object; both are
          stored with the document's embedding.

          :param documents: list of texts to add
          :type documents: list[str]
          :param metadatas: list of metadata associated with docs (``None`` entries are allowed)
          :type metadatas: list[object]
          :param ids: ids of docs
          :type ids: list[str]
          :param kwargs: extra keyword arguments forwarded to ``batch.add_reference``
          """
          if not documents:
              # Nothing to store: skip the embedding call and the batch round-trip.
              return
          embeddings = self.embedder.embedding_fn(documents)
          metadata_class = self.index_name + "_metadata"
          self.client.batch.configure(batch_size=self.BATCH_SIZE, timeout_retries=3)  # Configure batch
          with self.client.batch as batch:  # Initialize a batch process
              for doc_id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
                  doc = {"identifier": doc_id, "text": text}
                  updated_metadata = {"text": text}
                  if metadata is not None:
                      updated_metadata.update(metadata)
                  # Deep copies so the queued objects do not alias caller-owned (possibly nested) values.
                  obj_uuid = batch.add_data_object(
                      data_object=copy.deepcopy(doc), class_name=self.index_name, vector=embedding
                  )
                  metadata_uuid = batch.add_data_object(
                      data_object=copy.deepcopy(updated_metadata),
                      class_name=metadata_class,
                      vector=embedding,
                  )
                  batch.add_reference(obj_uuid, self.index_name, "metadata", metadata_uuid, metadata_class, **kwargs)

      def query(
          self,
          input_query: list[str],
          n_results: int,
          where: Optional[dict[str, any]] = None,
          citations: bool = False,
      ) -> Union[list[tuple[str, dict]], list[str]]:
          """
          query contents from vector database based on vector similarity

          :param input_query: query text; it is embedded as a single item
          :type input_query: str
          :param n_results: no of similar documents to fetch from database
          :type n_results: int
          :param where: Optional. to filter data; only keys in ``self.metadata_keys`` are applied
              (as an AND of equality checks on the linked metadata), other keys are ignored
          :type where: dict[str, any], optional
          :param citations: we use citations boolean param to return context along with the answer.
          :type citations: bool, default is False.
          :return: The content of the documents that matched your query. With citations, each item is
              ``(text, metadata)`` where metadata holds the linked metadata fields plus ``score``
              (the Weaviate ``distance`` of the match).
          :rtype: list[str], if citations=False, otherwise list[tuple[str, dict]]
          :raises RuntimeError: Weaviate reports errors for the query
          """
          query_vector = self.embedder.embedding_fn([input_query])[0]
          data_fields = ["text"]
          if citations:
              data_fields.append(weaviate.LinkTo("metadata", self.index_name + "_metadata", list(self.metadata_keys)))

          query_builder = self.client.query.get(self.index_name, data_fields)
          where_clause = self._build_where_clause(where)
          if where_clause is not None:
              query_builder = query_builder.with_where(where_clause)
          results = self._raise_for_errors(
              query_builder.with_near_vector({"vector": query_vector})
              .with_limit(n_results)
              .with_additional(["distance"])
              .do()
          )

          docs = results["data"]["Get"].get(self.index_name) or []
          contexts = []
          for doc in docs:
              context = doc["text"]
              if citations:
                  # A chunk without a linked metadata object yields null here; report empty metadata.
                  linked = doc.get("metadata") or [{}]
                  metadata = dict(linked[0])
                  metadata["score"] = doc["_additional"]["distance"]
                  contexts.append((context, metadata))
              else:
                  contexts.append(context)
          return contexts

      def _build_where_clause(self, where: Optional[dict[str, any]]) -> Optional[dict]:
          """Translate ``where`` into a Weaviate filter on the linked ``_metadata`` object.

          Only keys in ``self.metadata_keys`` are filterable; other keys are ignored.

          :return: a single ``Equal`` operand, an ``And`` of several, or None if nothing is filterable
          """
          if not where:
              return None
          metadata_class = self.index_name + "_metadata"
          operands = [
              {"path": ["metadata", metadata_class, key], "operator": "Equal", "valueText": value}
              for key, value in where.items()
              if key in self.metadata_keys
          ]
          if not operands:
              return None
          if len(operands) == 1:
              return operands[0]
          return {"operator": "And", "operands": operands}

      def set_collection_name(self, name: str):
          """
          Set the name of the collection. A collection is an isolated space for vectors.

          When the store is already initialized, the active index is switched to the new name and
          its schema is created if missing; otherwise `_initialize` picks the name up later.

          :param name: Name of the collection.
          :type name: str
          :raises TypeError: ``name`` is not a string
          """
          if not isinstance(name, str):
              raise TypeError("Collection name must be a string")
          self.config.collection_name = name
          if getattr(self, "index_name", None) is not None:
              # Without this, `index_name` keeps pointing at the class derived from the old name.
              self.index_name = self._get_index_name()
              self._create_schema_if_missing()

      def count(self) -> int:
          """
          Count number of documents/chunks embedded in the database.

          :return: number of documents
          :rtype: int
          :raises RuntimeError: Weaviate reports errors for the aggregate query
          """
          data = self._raise_for_errors(self.client.query.aggregate(self.index_name).with_meta_count().do())
          return data["data"]["Aggregate"].get(self.index_name)[0]["meta"]["count"]

      def _get_or_create_db(self):
          """Called during initialization"""
          return self.client

      def reset(self):
          """
          Resets the database. Deletes all embeddings irreversibly.

          Both the main class and its ``_metadata`` companion are dropped (objects and schema) and
          the empty schema is recreated.
          """
          # Dropping the classes replaces the old `Like ".*"` batch delete: Weaviate's `Like` uses
          # `*`/`?` wildcards rather than regex, so that filter matched (almost) nothing, and it
          # never removed the metadata objects.
          for class_name in (self.index_name, self.index_name + "_metadata"):
              if self.client.schema.exists(class_name):
                  self.client.schema.delete_class(class_name)
          self._create_schema_if_missing()

      # Weaviate internally by default capitalizes the class name
      def _get_index_name(self) -> str:
          """Get the Weaviate index for a collection

          NOTE: ``str.capitalize`` also lowercases every character after the first; kept as-is so
          existing class names keep resolving.

          :return: Weaviate index
          :rtype: str
          """
          return f"{self.config.collection_name}_{self.embedder.vector_dimension}".capitalize()

      @classmethod
      def _query_with_cursor(cls, query, cursor):
          """Run ``query``, resuming after the object UUID ``cursor`` when one is given.

          :raises RuntimeError: the GraphQL response reports errors
          """
          if cursor is not None:
              query.with_after(cursor)
          return cls._raise_for_errors(query.do())

      @staticmethod
      def _raise_for_errors(results: dict) -> dict:
          """Return a GraphQL response unchanged, or raise if it carries an ``errors`` entry."""
          errors = results.get("errors")
          if errors:
              raise RuntimeError(f"Weaviate query failed: {errors}")
          return results