weaviate.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297
  1. import copy
  2. import os
  3. from typing import Dict, List, Optional
  4. try:
  5. import weaviate
  6. except ImportError:
  7. raise ImportError(
  8. "Weaviate requires extra dependencies. Install with `pip install --upgrade 'embedchain[weaviate]'`"
  9. ) from None
  10. from embedchain.config.vectordb.weaviate import WeaviateDBConfig
  11. from embedchain.helper.json_serializable import register_deserializable
  12. from embedchain.vectordb.base import BaseVectorDB
  13. @register_deserializable
  14. class WeaviateDB(BaseVectorDB):
  15. """
  16. Weaviate as vector database
  17. """
  18. BATCH_SIZE = 100
  19. def __init__(
  20. self,
  21. config: Optional[WeaviateDBConfig] = None,
  22. ):
  23. """Weaviate as vector database.
  24. :param config: Weaviate database config, defaults to None
  25. :type config: WeaviateDBConfig, optional
  26. :raises ValueError: No config provided
  27. """
  28. if config is None:
  29. self.config = WeaviateDBConfig()
  30. else:
  31. if not isinstance(config, WeaviateDBConfig):
  32. raise TypeError(
  33. "config is not a `WeaviateDBConfig` instance. "
  34. "Please make sure the type is right and that you are passing an instance."
  35. )
  36. self.config = config
  37. self.client = weaviate.Client(
  38. url=os.environ.get("WEAVIATE_ENDPOINT"),
  39. auth_client_secret=weaviate.AuthApiKey(api_key=os.environ.get("WEAVIATE_API_KEY")),
  40. **self.config.extra_params,
  41. )
  42. # Call parent init here because embedder is needed
  43. super().__init__(config=self.config)
  44. def _initialize(self):
  45. """
  46. This method is needed because `embedder` attribute needs to be set externally before it can be initialized.
  47. """
  48. if not self.embedder:
  49. raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
  50. self.index_name = self._get_index_name()
  51. self.metadata_keys = {"data_type", "doc_id", "url", "hash", "app_id", "text"}
  52. if not self.client.schema.exists(self.index_name):
  53. # id is a reserved field in Weaviate, hence we had to change the name of the id field to identifier
  54. # The none vectorizer is crucial as we have our own custom embedding function
  55. class_obj = {
  56. "classes": [
  57. {
  58. "class": self.index_name,
  59. "vectorizer": "none",
  60. "properties": [
  61. {
  62. "name": "identifier",
  63. "dataType": ["text"],
  64. },
  65. {
  66. "name": "text",
  67. "dataType": ["text"],
  68. },
  69. {
  70. "name": "metadata",
  71. "dataType": [self.index_name + "_metadata"],
  72. },
  73. ],
  74. },
  75. {
  76. "class": self.index_name + "_metadata",
  77. "vectorizer": "none",
  78. "properties": [
  79. {
  80. "name": "data_type",
  81. "dataType": ["text"],
  82. },
  83. {
  84. "name": "doc_id",
  85. "dataType": ["text"],
  86. },
  87. {
  88. "name": "url",
  89. "dataType": ["text"],
  90. },
  91. {
  92. "name": "hash",
  93. "dataType": ["text"],
  94. },
  95. {
  96. "name": "app_id",
  97. "dataType": ["text"],
  98. },
  99. {
  100. "name": "text",
  101. "dataType": ["text"],
  102. },
  103. ],
  104. },
  105. ]
  106. }
  107. self.client.schema.create(class_obj)
  108. def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
  109. """
  110. Get existing doc ids present in vector database
  111. :param ids: _list of doc ids to check for existance
  112. :type ids: List[str]
  113. :param where: to filter data
  114. :type where: Dict[str, any]
  115. :return: ids
  116. :rtype: Set[str]
  117. """
  118. if ids is None or len(ids) == 0:
  119. return {"ids": []}
  120. existing_ids = []
  121. cursor = None
  122. has_iterated_once = False
  123. while cursor is not None or not has_iterated_once:
  124. has_iterated_once = True
  125. results = self._query_with_cursor(
  126. self.client.query.get(self.index_name, ["identifier"])
  127. .with_additional(["id"])
  128. .with_limit(self.BATCH_SIZE),
  129. cursor,
  130. )
  131. fetched_results = results["data"]["Get"].get(self.index_name, [])
  132. if len(fetched_results) == 0:
  133. break
  134. for result in fetched_results:
  135. existing_ids.append(result["identifier"])
  136. cursor = result["_additional"]["id"]
  137. return {"ids": existing_ids}
  138. def add(
  139. self,
  140. embeddings: List[List[float]],
  141. documents: List[str],
  142. metadatas: List[object],
  143. ids: List[str],
  144. skip_embedding: bool,
  145. ):
  146. """add data in vector database
  147. :param embeddings: list of embeddings for the corresponding documents to be added
  148. :type documents: List[List[float]]
  149. :param documents: list of texts to add
  150. :type documents: List[str]
  151. :param metadatas: list of metadata associated with docs
  152. :type metadatas: List[object]
  153. :param ids: ids of docs
  154. :type ids: List[str]
  155. :param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
  156. generated or not
  157. :type skip_embedding: bool
  158. """
  159. print("Adding documents to Weaviate...")
  160. if not skip_embedding:
  161. embeddings = self.embedder.embedding_fn(documents)
  162. self.client.batch.configure(batch_size=self.BATCH_SIZE, timeout_retries=3) # Configure batch
  163. with self.client.batch as batch: # Initialize a batch process
  164. for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
  165. doc = {"identifier": id, "text": text}
  166. updated_metadata = {"text": text}
  167. if metadata is not None:
  168. updated_metadata.update(**metadata)
  169. obj_uuid = batch.add_data_object(
  170. data_object=copy.deepcopy(doc), class_name=self.index_name, vector=embedding
  171. )
  172. metadata_uuid = batch.add_data_object(
  173. data_object=copy.deepcopy(updated_metadata),
  174. class_name=self.index_name + "_metadata",
  175. vector=embedding,
  176. )
  177. batch.add_reference(obj_uuid, self.index_name, "metadata", metadata_uuid, self.index_name + "_metadata")
  178. def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
  179. """
  180. query contents from vector database based on vector similarity
  181. :param input_query: list of query string
  182. :type input_query: List[str]
  183. :param n_results: no of similar documents to fetch from database
  184. :type n_results: int
  185. :param where: Optional. to filter data
  186. :type where: Dict[str, any]
  187. :param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
  188. generated or not
  189. :type skip_embedding: bool
  190. :return: Database contents that are the result of the query
  191. :rtype: List[str]
  192. """
  193. if not skip_embedding:
  194. query_vector = self.embedder.embedding_fn([input_query])[0]
  195. else:
  196. query_vector = input_query
  197. keys = set(where.keys() if where is not None else set())
  198. if len(keys.intersection(self.metadata_keys)) != 0:
  199. weaviate_where_operands = []
  200. for key in keys:
  201. if key in self.metadata_keys:
  202. weaviate_where_operands.append(
  203. {
  204. "path": ["metadata", self.index_name + "_metadata", key],
  205. "operator": "Equal",
  206. "valueText": where.get(key),
  207. }
  208. )
  209. if len(weaviate_where_operands) == 1:
  210. weaviate_where_clause = weaviate_where_operands[0]
  211. else:
  212. weaviate_where_clause = {"operator": "And", "operands": weaviate_where_operands}
  213. results = (
  214. self.client.query.get(self.index_name, ["text"])
  215. .with_where(weaviate_where_clause)
  216. .with_near_vector({"vector": query_vector})
  217. .with_limit(n_results)
  218. .do()
  219. )
  220. else:
  221. results = (
  222. self.client.query.get(self.index_name, ["text"])
  223. .with_near_vector({"vector": query_vector})
  224. .with_limit(n_results)
  225. .do()
  226. )
  227. matched_tokens = []
  228. for result in results["data"]["Get"].get(self.index_name):
  229. matched_tokens.append(result["text"])
  230. return matched_tokens
  231. def set_collection_name(self, name: str):
  232. """
  233. Set the name of the collection. A collection is an isolated space for vectors.
  234. :param name: Name of the collection.
  235. :type name: str
  236. """
  237. if not isinstance(name, str):
  238. raise TypeError("Collection name must be a string")
  239. self.config.collection_name = name
  240. def count(self) -> int:
  241. """
  242. Count number of documents/chunks embedded in the database.
  243. :return: number of documents
  244. :rtype: int
  245. """
  246. data = self.client.query.aggregate(self.index_name).with_meta_count().do()
  247. return data["data"]["Aggregate"].get(self.index_name)[0]["meta"]["count"]
  248. def _get_or_create_db(self):
  249. """Called during initialization"""
  250. return self.client
  251. def reset(self):
  252. """
  253. Resets the database. Deletes all embeddings irreversibly.
  254. """
  255. # Delete all data from the database
  256. self.client.batch.delete_objects(
  257. self.index_name, where={"path": ["identifier"], "operator": "Like", "valueText": ".*"}
  258. )
  259. # Weaviate internally by default capitalizes the class name
  260. def _get_index_name(self) -> str:
  261. """Get the Weaviate index for a collection
  262. :return: Weaviate index
  263. :rtype: str
  264. """
  265. return f"{self.config.collection_name}_{self.embedder.vector_dimension}".capitalize()
  266. def _query_with_cursor(self, query, cursor):
  267. if cursor is not None:
  268. query.with_after(cursor)
  269. results = query.do()
  270. return results