浏览代码

[bugfix] fix elasticsearch db (#1183)

Co-authored-by: Deven Patel <deven298@yahoo.com>
Deven Patel 1 年之前
父节点
当前提交
3cc3a0058d
共有 1 个文件被更改，包括 28 次插入和 8 次删除
  1. 28 8
      embedchain/vectordb/elasticsearch.py

+ 28 - 8
embedchain/vectordb/elasticsearch.py

@@ -99,14 +99,24 @@ class ElasticsearchDB(BaseVectorDB):
             query = {"bool": {"must": [{"ids": {"values": ids}}]}}
         else:
             query = {"bool": {"must": []}}
-        if "app_id" in where:
-            app_id = where["app_id"]
-            query["bool"]["must"].append({"term": {"metadata.app_id": app_id}})
 
-        response = self.client.search(index=self._get_index(), query=query, _source=False, size=limit)
+        if where:
+            for key, value in where.items():
+                query["bool"]["must"].append({"term": {f"metadata.{key}.keyword": value}})
+
+        response = self.client.search(index=self._get_index(), query=query, _source=True, size=limit)
         docs = response["hits"]["hits"]
         ids = [doc["_id"] for doc in docs]
-        return {"ids": set(ids)}
+        doc_ids = [doc["_source"]["metadata"]["doc_id"] for doc in docs]
+
+        # Result is modified for compatibility with other vector databases
+        # TODO: Add method in vector database to return result in a standard format
+        result = {"ids": ids, "metadatas": []}
+
+        for doc_id in doc_ids:
+            result["metadatas"].append({"doc_id": doc_id})
+
+        return result
 
     def add(
         self,
@@ -186,9 +196,11 @@ class ElasticsearchDB(BaseVectorDB):
                 },
             }
         }
-        if "app_id" in where:
-            app_id = where["app_id"]
-            query["script_score"]["query"] = {"match": {"metadata.app_id": app_id}}
+
+        if where:
+            for key, value in where.items():
+                query["script_score"]["query"]["bool"]["must"].append({"term": {f"metadata.{key}.keyword": value}})
+
         _source = ["text", "metadata"]
         response = self.client.search(index=self._get_index(), query=query, _source=_source, size=n_results)
         docs = response["hits"]["hits"]
@@ -244,3 +256,11 @@ class ElasticsearchDB(BaseVectorDB):
         # NOTE: The method is preferred to an attribute, because if collection name changes,
         # it's always up-to-date.
         return f"{self.config.collection_name}_{self.embedder.vector_dimension}".lower()
+
+    def delete(self, where):
+        """Delete documents from the database."""
+        query = {"query": {"bool": {"must": []}}}
+        for key, value in where.items():
+            query["query"]["bool"]["must"].append({"term": {f"metadata.{key}.keyword": value}})
+        self.client.delete_by_query(index=self._get_index(), body=query)
+        self.client.indices.refresh(index=self._get_index())