소스 검색

[Bug Fix] fix chromadb where clause for query and delete (#937)

Co-authored-by: Deven Patel <deven298@yahoo.com>
Deven Patel 1 년 전
부모
커밋
deaa7f50f8
5개의 변경된 파일, 15개의 추가 및 13개의 삭제
  1. +2 -1
      embedchain/apps/app.py
  2. +3 -2
      embedchain/bots/base.py
  3. +5 -5
      embedchain/embedchain.py
  4. +4 -4
      embedchain/vectordb/chroma.py
  5. +1 -1
      pyproject.toml

+ 2 - 1
embedchain/apps/app.py

@@ -3,7 +3,8 @@ from typing import Optional
 import yaml
 
 from embedchain.client import Client
-from embedchain.config import AppConfig, BaseEmbedderConfig, BaseLlmConfig, ChunkerConfig
+from embedchain.config import (AppConfig, BaseEmbedderConfig, BaseLlmConfig,
+                               ChunkerConfig)
 from embedchain.config.vectordb.base import BaseVectorDbConfig
 from embedchain.embedchain import EmbedChain
 from embedchain.embedder.base import BaseEmbedder

+ 3 - 2
embedchain/bots/base.py

@@ -1,9 +1,10 @@
 from typing import Any
 
 from embedchain import Pipeline as App
-from embedchain.config import AddConfig, PipelineConfig, BaseLlmConfig
+from embedchain.config import AddConfig, BaseLlmConfig, PipelineConfig
 from embedchain.embedder.openai import OpenAIEmbedder
-from embedchain.helper.json_serializable import JSONSerializable, register_deserializable
+from embedchain.helper.json_serializable import (JSONSerializable,
+                                                 register_deserializable)
 from embedchain.llm.openai import OpenAILlm
 from embedchain.vectordb.chroma import ChromaDB
 

+ 5 - 5
embedchain/embedchain.py

@@ -478,13 +478,13 @@ class EmbedChain(JSONSerializable):
         query_config = config or self.llm.config
         if where is not None:
             where = where
-        elif query_config is not None and query_config.where is not None:
-            where = query_config.where
         else:
             where = {}
-
-        if self.config.id is not None:
-            where.update({"app_id": self.config.id})
+            if query_config is not None and query_config.where is not None:
+                where = query_config.where
+            
+            if self.config.id is not None:
+                where.update({"app_id": self.config.id})
 
         # We cannot query the database with the input query in case of an image search. This is because we need
         # to bring down both the image and text to the same dimension to be able to compare them.

+ 4 - 4
embedchain/vectordb/chroma.py

@@ -77,7 +77,7 @@ class ChromaDB(BaseVectorDB):
     def _generate_where_clause(self, where: Dict[str, any]) -> str:
         # If only one filter is supplied, return it as is
         # (no need to wrap in $and based on chroma docs)
-        if len(where.keys()) == 1:
+        if len(where.keys()) <= 1:
             return where
         where_filters = []
         for k, v in where.items():
@@ -224,7 +224,7 @@ class ChromaDB(BaseVectorDB):
                         input_query,
                     ],
                     n_results=n_results,
-                    where=where,
+                    where=self._generate_where_clause(where),
                 )
             else:
                 result = self.collection.query(
@@ -232,7 +232,7 @@ class ChromaDB(BaseVectorDB):
                         input_query,
                     ],
                     n_results=n_results,
-                    where=where,
+                    where=self._generate_where_clause(where),
                 )
         except InvalidDimensionException as e:
             raise InvalidDimensionException(
@@ -275,7 +275,7 @@ class ChromaDB(BaseVectorDB):
         return self.collection.count()
 
     def delete(self, where):
-        return self.collection.delete(where=where)
+        return self.collection.delete(where=self._generate_where_clause(where))
 
     def reset(self):
         """

+ 1 - 1
pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "embedchain"
-version = "0.1.5"
+version = "0.1.6"
 description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
 authors = [
     "Taranjeet Singh <taranjeet@embedchain.ai>",