瀏覽代碼

[Bugfix] fix opensearch db (#1184)

Co-authored-by: Deven Patel <deven298@yahoo.com>
Deven Patel 1 年之前
父節點
當前提交
f50f8a444a
共有 1 個文件被更改,包括 11 次插入10 次删除
  1. 11 10
      embedchain/vectordb/opensearch.py

+ 11 - 10
embedchain/vectordb/opensearch.py

@@ -96,9 +96,9 @@ class OpenSearchDB(BaseVectorDB):
         else:
             query["query"] = {"bool": {"must": []}}
 
-        if "app_id" in where:
-            app_id = where["app_id"]
-            query["query"]["bool"]["must"].append({"term": {"metadata.app_id.keyword": app_id}})
+        if where:
+            for key, value in where.items():
+                query["bool"]["must"].append({"term": {f"metadata.{key}.keyword": value}})
 
         # OpenSearch syntax is different from Elasticsearch
         response = self.client.search(index=self._get_index(), body=query, _source=True, size=limit)
@@ -176,9 +176,11 @@ class OpenSearchDB(BaseVectorDB):
         )
 
         pre_filter = {"match_all": {}}  # default
-        if "app_id" in where:
-            app_id = where["app_id"]
-            pre_filter = {"bool": {"must": [{"term": {"metadata.app_id.keyword": app_id}}]}}
+        if len(where) > 0:
+            pre_filter = {"bool": {"must": []}}
+            for key, value in where.items():
+                pre_filter["bool"]["must"].append({"term": {f"metadata.{key}.keyword": value}})
+
         docs = docsearch.similarity_search_with_score(
             input_query,
             search_type="script_scoring",
@@ -236,10 +238,9 @@ class OpenSearchDB(BaseVectorDB):
 
     def delete(self, where):
         """Deletes a document from the OpenSearch index"""
-        if "doc_id" not in where:
-            raise ValueError("doc_id is required to delete a document")
-
-        query = {"query": {"bool": {"must": [{"term": {"metadata.doc_id": where["doc_id"]}}]}}}
+        query = {"query": {"bool": {"must": []}}}
+        for key, value in where.items():
+            query["query"]["bool"]["must"].append({"term": {f"metadata.{key}.keyword": value}})
         self.client.delete_by_query(index=self._get_index(), body=query)
 
     def _get_index(self) -> str: