# weaviate.py (11 KB)
  1. import copy
  2. import os
  3. from typing import Dict, List, Optional, Tuple
  4. try:
  5. import weaviate
  6. except ImportError:
  7. raise ImportError(
  8. "Weaviate requires extra dependencies. Install with `pip install --upgrade 'embedchain[weaviate]'`"
  9. ) from None
  10. from embedchain.config.vectordb.weaviate import WeaviateDBConfig
  11. from embedchain.helper.json_serializable import register_deserializable
  12. from embedchain.vectordb.base import BaseVectorDB
  @register_deserializable
  class WeaviateDB(BaseVectorDB):
      """
      Weaviate as vector database.

      Every chunk is stored as two linked Weaviate objects: one in the main class
      (``identifier`` + ``text``, carrying the embedding) and one in the companion
      ``<index>_metadata`` class holding the chunk's metadata. The two are joined through
      the ``metadata`` reference property of the main object.
      """

      # Page size for cursor pagination in `get` and batch size for imports in `add`.
      BATCH_SIZE = 100

      def __init__(
          self,
          config: Optional[WeaviateDBConfig] = None,
      ):
          """Weaviate as vector database.

          The server URL is read from the ``WEAVIATE_ENDPOINT`` environment variable and the
          optional API key from ``WEAVIATE_API_KEY``. ``config.extra_params`` is forwarded
          verbatim to ``weaviate.Client``.

          :param config: Weaviate database config, defaults to None (a default ``WeaviateDBConfig`` is used)
          :type config: WeaviateDBConfig, optional
          :raises TypeError: ``config`` is given but is not a ``WeaviateDBConfig`` instance
          """
          if config is None:
              self.config = WeaviateDBConfig()
          else:
              if not isinstance(config, WeaviateDBConfig):
                  raise TypeError(
                      "config is not a `WeaviateDBConfig` instance. "
                      "Please make sure the type is right and that you are passing an instance."
                  )
              self.config = config
          api_key = os.environ.get("WEAVIATE_API_KEY")
          self.client = weaviate.Client(
              url=os.environ.get("WEAVIATE_ENDPOINT"),
              # AuthApiKey expects a real key; skip authentication entirely when none is configured.
              auth_client_secret=weaviate.AuthApiKey(api_key=api_key) if api_key else None,
              **self.config.extra_params,
          )
          # Call parent init here because embedder is needed
          super().__init__(config=self.config)

      def _initialize(self):
          """
          Compute the index (class) name and create the Weaviate schema if it does not exist yet.

          This method is needed because the `embedder` attribute must be set externally before
          initialization: the class name includes the embedder's vector dimension.

          :raises ValueError: no embedder has been set
          """
          if not self.embedder:
              raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
          self.index_name = self._get_index_name()
          self.metadata_keys = {"data_type", "doc_id", "url", "hash", "app_id", "text"}
          if not self.client.schema.exists(self.index_name):
              metadata_class = self.index_name + "_metadata"
              # `id` is a reserved field in Weaviate, hence the chunk id is stored as `identifier`.
              # The "none" vectorizer is crucial: embeddings come from our own embedding function.
              class_obj = {
                  "classes": [
                      {
                          "class": self.index_name,
                          "vectorizer": "none",
                          "properties": [
                              {"name": "identifier", "dataType": ["text"]},
                              {"name": "text", "dataType": ["text"]},
                              # Cross-reference to the companion metadata object.
                              {"name": "metadata", "dataType": [metadata_class]},
                          ],
                      },
                      {
                          "class": metadata_class,
                          "vectorizer": "none",
                          "properties": [
                              {"name": key, "dataType": ["text"]}
                              for key in ("data_type", "doc_id", "url", "hash", "app_id", "text")
                          ],
                      },
                  ]
              }
              self.client.schema.create(class_obj)

      def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
          """
          Get which of the given doc ids already exist in the vector database.

          Pages through the main class with Weaviate's cursor (``after``) API. Only the
          identifiers that were asked for are kept.

          NOTE: ``where`` is not applied. Cursor pagination cannot be combined with a where
          filter in Weaviate, so results are not narrowed by metadata here.

          :param ids: list of doc ids to check for existence; None or empty returns no ids
          :type ids: List[str], optional
          :param where: accepted for interface compatibility, currently ignored
          :type where: Dict[str, any], optional
          :param limit: maximum number of ids to return, defaults to no limit
          :type limit: int, optional
          :return: ``{"ids": [...]}`` holding the requested ids that exist, in cursor order
          :rtype: Dict[str, List[str]]
          """
          if not ids:
              return {"ids": []}

          wanted = set(ids)  # O(1) membership checks while scanning pages
          existing_ids = []
          cursor = None
          while True:
              query = (
                  self.client.query.get(self.index_name, ["identifier"])
                  .with_additional(["id"])
                  .with_limit(self.BATCH_SIZE)
              )
              results = self._query_with_cursor(query, cursor)
              page = results["data"]["Get"].get(self.index_name) or []
              existing_ids.extend(item["identifier"] for item in page if item["identifier"] in wanted)
              if limit is not None and len(existing_ids) >= limit:
                  return {"ids": existing_ids[:limit]}
              # A short (or empty) page means the cursor has reached the end of the class.
              if len(page) < self.BATCH_SIZE:
                  break
              # The next page starts after the Weaviate UUID of the last object seen.
              cursor = page[-1]["_additional"]["id"]
          return {"ids": existing_ids}

      def add(
          self,
          embeddings: List[List[float]],
          documents: List[str],
          metadatas: List[object],
          ids: List[str],
          skip_embedding: bool,
      ):
          """Add documents and their metadata to the vector database.

          Each document becomes one object in the main class and one in the ``<index>_metadata``
          class, linked through the ``metadata`` reference. Both objects receive the embedding.

          :param embeddings: list of embeddings for the corresponding documents to be added
          :type embeddings: List[List[float]]
          :param documents: list of texts to add
          :type documents: List[str]
          :param metadatas: list of metadata dicts (or None) associated with docs
          :type metadatas: List[object]
          :param ids: ids of docs
          :type ids: List[str]
          :param skip_embedding: if False, ``embeddings`` is ignored and recomputed from ``documents``
              with the configured embedder; if True, ``embeddings`` is used as given
          :type skip_embedding: bool
          """
          print("Adding documents to Weaviate...")
          if not skip_embedding:
              embeddings = self.embedder.embedding_fn(documents)
          metadata_class = self.index_name + "_metadata"
          self.client.batch.configure(batch_size=self.BATCH_SIZE, timeout_retries=3)  # Configure batch
          with self.client.batch as batch:  # Initialize a batch process
              for doc_id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
                  doc = {"identifier": doc_id, "text": text}
                  # The chunk text is also stored on the metadata object; an explicit "text"
                  # key in the caller's metadata overrides it.
                  updated_metadata = {"text": text}
                  if metadata is not None:
                      updated_metadata.update(**metadata)
                  obj_uuid = batch.add_data_object(
                      data_object=copy.deepcopy(doc), class_name=self.index_name, vector=embedding
                  )
                  metadata_uuid = batch.add_data_object(
                      data_object=copy.deepcopy(updated_metadata),
                      class_name=metadata_class,
                      vector=embedding,
                  )
                  batch.add_reference(obj_uuid, self.index_name, "metadata", metadata_uuid, metadata_class)

      def query(
          self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool
      ) -> List[Dict[str, str]]:
          """
          Query contents from vector database based on vector similarity.

          :param input_query: query text (embedded via the embedder), or a precomputed query
              vector when ``skip_embedding`` is True
          :type input_query: List[str]
          :param n_results: number of similar documents to fetch from database
          :type n_results: int
          :param where: Optional. Filter data; only keys in ``self.metadata_keys`` are applied
              (exact ``Equal`` match on the linked metadata object), other keys are ignored
          :type where: Dict[str, any]
          :param skip_embedding: if True, ``input_query`` is already a vector and is not embedded
          :type skip_embedding: bool
          :return: matching objects as returned by Weaviate, each a dict with a ``"text"`` key;
              an empty list when Weaviate returns nothing for the class.
              NOTE(review): other backends may return plain strings here — confirm what callers expect.
          :rtype: List[Dict[str, str]]
          """
          if not skip_embedding:
              query_vector = self.embedder.embedding_fn([input_query])[0]
          else:
              query_vector = input_query

          builder = self.client.query.get(self.index_name, ["text"])
          where_clause = self._build_where_clause(where)
          if where_clause is not None:
              builder = builder.with_where(where_clause)
          results = builder.with_near_vector({"vector": query_vector}).with_limit(n_results).do()
          return results["data"]["Get"].get(self.index_name) or []

      def _build_where_clause(self, where: Optional[Dict[str, any]]) -> Optional[Dict[str, any]]:
          """Translate an embedchain ``where`` dict into a Weaviate where filter.

          Each key listed in ``self.metadata_keys`` becomes an ``Equal`` operand on the linked
          metadata object. Several operands are combined with ``And``.

          :return: the Weaviate filter, or None when ``where`` has no usable key
          """
          operands = [
              {
                  "path": ["metadata", self.index_name + "_metadata", key],
                  "operator": "Equal",
                  "valueText": value,
              }
              for key, value in (where or {}).items()
              if key in self.metadata_keys
          ]
          if not operands:
              return None
          if len(operands) == 1:
              return operands[0]
          return {"operator": "And", "operands": operands}

      def set_collection_name(self, name: str):
          """
          Set the name of the collection. A collection is an isolated space for vectors.

          If the database was already initialized, the index name is recomputed and the schema
          is created for the new collection, so subsequent calls target it.

          :param name: Name of the collection.
          :type name: str
          :raises TypeError: ``name`` is not a string
          """
          if not isinstance(name, str):
              raise TypeError("Collection name must be a string")
          self.config.collection_name = name
          # `index_name` only exists after `_initialize` has run; before that, initialization
          # will pick up the new name by itself.
          if getattr(self, "index_name", None) is not None:
              self._initialize()

      def count(self) -> int:
          """
          Count number of documents/chunks embedded in the database.

          Only objects of the main class are counted (not the companion metadata objects).

          :return: number of documents
          :rtype: int
          """
          data = self.client.query.aggregate(self.index_name).with_meta_count().do()
          return data["data"]["Aggregate"].get(self.index_name)[0]["meta"]["count"]

      def _get_or_create_db(self):
          """Called during initialization; the Weaviate client itself serves as the database handle."""
          return self.client

      def reset(self):
          """
          Resets the database. Deletes all embeddings irreversibly.

          Removes the objects of both the main class and its ``_metadata`` companion class. The
          schema itself is kept.
          NOTE(review): Weaviate caps how many objects one batch delete removes
          (QUERY_MAXIMUM_RESULTS); very large indexes may need repeated calls — confirm.
          """
          # `Like` uses wildcard syntax (`*` = zero or more characters, `?` = one character),
          # not regular expressions, so "*" is the match-everything pattern.
          self.client.batch.delete_objects(
              self.index_name, where={"path": ["identifier"], "operator": "Like", "valueText": "*"}
          )
          # Metadata objects always get a `text` property in `add`; delete them too so they are
          # not left orphaned.
          self.client.batch.delete_objects(
              self.index_name + "_metadata", where={"path": ["text"], "operator": "Like", "valueText": "*"}
          )

      def _get_index_name(self) -> str:
          """Get the Weaviate index (class) name for a collection.

          Weaviate capitalizes the first letter of class names. Note that ``str.capitalize()``
          additionally lowercases every other character. This is kept as-is so classes created
          by earlier versions remain reachable.

          :return: Weaviate index
          :rtype: str
          """
          return f"{self.config.collection_name}_{self.embedder.vector_dimension}".capitalize()

      def _query_with_cursor(self, query, cursor):
          """Run ``query``, starting after the object with UUID ``cursor`` when one is given."""
          if cursor is not None:
              query.with_after(cursor)
          results = query.do()
          return results