|
@@ -45,6 +45,9 @@ class WeaviateDB(BaseVectorDB):
|
|
|
auth_client_secret=weaviate.AuthApiKey(api_key=os.environ.get("WEAVIATE_API_KEY")),
|
|
|
**self.config.extra_params,
|
|
|
)
|
|
|
+ # Since weaviate uses graphQL, we need to keep track of metadata keys added in the vectordb.
|
|
|
+ # This is needed to filter data while querying.
|
|
|
+ self.metadata_keys = {"data_type", "doc_id", "url", "hash", "app_id"}
|
|
|
|
|
|
# Call parent init here because embedder is needed
|
|
|
super().__init__(config=self.config)
|
|
@@ -58,7 +61,6 @@ class WeaviateDB(BaseVectorDB):
|
|
|
raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
|
|
|
|
|
|
self.index_name = self._get_index_name()
|
|
|
- self.metadata_keys = {"data_type", "doc_id", "url", "hash", "app_id"}
|
|
|
if not self.client.schema.exists(self.index_name):
|
|
|
# id is a reserved field in Weaviate, hence we had to change the name of the id field to identifier
|
|
|
# The none vectorizer is crucial as we have our own custom embedding function
|
|
@@ -127,29 +129,64 @@ class WeaviateDB(BaseVectorDB):
|
|
|
:return: ids
|
|
|
:rtype: Set[str]
|
|
|
"""
|
|
|
+ weaviate_where_operands = []
|
|
|
+
|
|
|
+ if ids:
|
|
|
+ for doc_id in ids:
|
|
|
+ weaviate_where_operands.append({"path": ["identifier"], "operator": "Equal", "valueText": doc_id})
|
|
|
+
|
|
|
+ keys = set(where.keys() if where is not None else set())
|
|
|
+ if len(keys) > 0:
|
|
|
+ for key in keys:
|
|
|
+ weaviate_where_operands.append(
|
|
|
+ {
|
|
|
+ "path": ["metadata", self.index_name + "_metadata", key],
|
|
|
+ "operator": "Equal",
|
|
|
+ "valueText": where.get(key),
|
|
|
+ }
|
|
|
+ )
|
|
|
|
|
|
- if ids is None or len(ids) == 0:
|
|
|
- return {"ids": []}
|
|
|
+ if len(weaviate_where_operands) == 1:
|
|
|
+ weaviate_where_clause = weaviate_where_operands[0]
|
|
|
+ else:
|
|
|
+ weaviate_where_clause = {"operator": "And", "operands": weaviate_where_operands}
|
|
|
|
|
|
existing_ids = []
|
|
|
+ metadatas = []
|
|
|
cursor = None
|
|
|
+ offset = 0
|
|
|
has_iterated_once = False
|
|
|
+ query_metadata_keys = self.metadata_keys.union(keys)
|
|
|
while cursor is not None or not has_iterated_once:
|
|
|
has_iterated_once = True
|
|
|
- results = self._query_with_cursor(
|
|
|
- self.client.query.get(self.index_name, ["identifier"])
|
|
|
+ results = self._query_with_offset(
|
|
|
+ self.client.query.get(
|
|
|
+ self.index_name,
|
|
|
+ [
|
|
|
+ "identifier",
|
|
|
+ weaviate.LinkTo("metadata", self.index_name + "_metadata", list(query_metadata_keys)),
|
|
|
+ ],
|
|
|
+ )
|
|
|
+ .with_where(weaviate_where_clause)
|
|
|
.with_additional(["id"])
|
|
|
- .with_limit(self.BATCH_SIZE),
|
|
|
- cursor,
|
|
|
+ .with_limit(limit or self.BATCH_SIZE),
|
|
|
+ offset,
|
|
|
)
|
|
|
+
|
|
|
fetched_results = results["data"]["Get"].get(self.index_name, [])
|
|
|
- if len(fetched_results) == 0:
|
|
|
+ if not fetched_results:
|
|
|
break
|
|
|
+
|
|
|
for result in fetched_results:
|
|
|
existing_ids.append(result["identifier"])
|
|
|
+ metadatas.append(result["metadata"][0])
|
|
|
cursor = result["_additional"]["id"]
|
|
|
+ offset += 1
|
|
|
|
|
|
- return {"ids": existing_ids}
|
|
|
+ if limit is not None and len(existing_ids) >= limit:
|
|
|
+ break
|
|
|
+
|
|
|
+ return {"ids": existing_ids, "metadatas": metadatas}
|
|
|
|
|
|
def add(self, documents: list[str], metadatas: list[object], ids: list[str], **kwargs: Optional[dict[str, any]]):
|
|
|
"""add data in vector database
|
|
@@ -201,21 +238,20 @@ class WeaviateDB(BaseVectorDB):
|
|
|
query_vector = self.embedder.embedding_fn([input_query])[0]
|
|
|
keys = set(where.keys() if where is not None else set())
|
|
|
data_fields = ["text"]
|
|
|
-
|
|
|
+ query_metadata_keys = self.metadata_keys.union(keys)
|
|
|
if citations:
|
|
|
- data_fields.append(weaviate.LinkTo("metadata", self.index_name + "_metadata", list(self.metadata_keys)))
|
|
|
+ data_fields.append(weaviate.LinkTo("metadata", self.index_name + "_metadata", list(query_metadata_keys)))
|
|
|
|
|
|
- if len(keys.intersection(self.metadata_keys)) != 0:
|
|
|
+ if len(keys) > 0:
|
|
|
weaviate_where_operands = []
|
|
|
for key in keys:
|
|
|
- if key in self.metadata_keys:
|
|
|
- weaviate_where_operands.append(
|
|
|
- {
|
|
|
- "path": ["metadata", self.index_name + "_metadata", key],
|
|
|
- "operator": "Equal",
|
|
|
- "valueText": where.get(key),
|
|
|
- }
|
|
|
- )
|
|
|
+ weaviate_where_operands.append(
|
|
|
+ {
|
|
|
+ "path": ["metadata", self.index_name + "_metadata", key],
|
|
|
+ "operator": "Equal",
|
|
|
+ "valueText": where.get(key),
|
|
|
+ }
|
|
|
+ )
|
|
|
if len(weaviate_where_operands) == 1:
|
|
|
weaviate_where_clause = weaviate_where_operands[0]
|
|
|
else:
|
|
@@ -289,11 +325,37 @@ class WeaviateDB(BaseVectorDB):
|
|
|
:return: Weaviate index
|
|
|
:rtype: str
|
|
|
"""
|
|
|
- return f"{self.config.collection_name}_{self.embedder.vector_dimension}".capitalize()
|
|
|
+ return f"{self.config.collection_name}_{self.embedder.vector_dimension}".capitalize().replace("-", "_")
|
|
|
|
|
|
@staticmethod
|
|
|
- def _query_with_cursor(query, cursor):
|
|
|
- if cursor is not None:
|
|
|
- query.with_after(cursor)
|
|
|
+ def _query_with_offset(query, offset):
|
|
|
+ if offset:
|
|
|
+ query.with_offset(offset)
|
|
|
results = query.do()
|
|
|
return results
|
|
|
+
|
|
|
+ def _generate_query(self, where: dict):
|
|
|
+ weaviate_where_operands = []
|
|
|
+ for key, value in where.items():
|
|
|
+ weaviate_where_operands.append(
|
|
|
+ {
|
|
|
+ "path": ["metadata", self.index_name + "_metadata", key],
|
|
|
+ "operator": "Equal",
|
|
|
+ "valueText": value,
|
|
|
+ }
|
|
|
+ )
|
|
|
+
|
|
|
+ if len(weaviate_where_operands) == 1:
|
|
|
+ weaviate_where_clause = weaviate_where_operands[0]
|
|
|
+ else:
|
|
|
+ weaviate_where_clause = {"operator": "And", "operands": weaviate_where_operands}
|
|
|
+
|
|
|
+ return weaviate_where_clause
|
|
|
+
|
|
|
+ def delete(self, where: dict):
|
|
|
+ """Delete from database.
|
|
|
+ :param where: to filter data
|
|
|
+ :type where: dict[str, any]
|
|
|
+ """
|
|
|
+ query = self._generate_query(where)
|
|
|
+ self.client.batch.delete_objects(self.index_name, where=query)
|