Ver código fonte

Change list[str] -> str for vectordbs (#1388)

Anu 1 ano atrás
pai
commit
e0d0f6e94c

+ 3 - 3
embedchain/vectordb/chroma.py

@@ -183,7 +183,7 @@ class ChromaDB(BaseVectorDB):
 
     def query(
         self,
-        input_query: list[str],
+        input_query: str,
         n_results: int,
         where: Optional[dict[str, any]] = None,
         raw_filter: Optional[dict[str, any]] = None,
@@ -193,8 +193,8 @@ class ChromaDB(BaseVectorDB):
         """
         Query contents from vector database based on vector similarity
 
-        :param input_query: list of query string
-        :type input_query: list[str]
+        :param input_query: query string
+        :type input_query: str
         :param n_results: no of similar documents to fetch from database
         :type n_results: int
         :param where: to filter data

+ 4 - 4
embedchain/vectordb/elasticsearch.py

@@ -163,7 +163,7 @@ class ElasticsearchDB(BaseVectorDB):
 
     def query(
         self,
-        input_query: list[str],
+        input_query: str,
         n_results: int,
         where: dict[str, any],
         citations: bool = False,
@@ -172,8 +172,8 @@ class ElasticsearchDB(BaseVectorDB):
         """
         query contents from vector database based on vector similarity
 
-        :param input_query: list of query string
-        :type input_query: list[str]
+        :param input_query: query string
+        :type input_query: str
         :param n_results: no of similar documents to fetch from database
         :type n_results: int
         :param where: Optional. to filter data
@@ -185,7 +185,7 @@ class ElasticsearchDB(BaseVectorDB):
         along with url of the source and doc_id (if citations flag is true)
         :rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
         """
-        input_query_vector = self.embedder.embedding_fn(input_query)
+        input_query_vector = self.embedder.embedding_fn([input_query])
         query_vector = input_query_vector[0]
 
         # `https://www.elastic.co/guide/en/elasticsearch/reference/7.17/query-dsl-script-score-query.html`

+ 3 - 3
embedchain/vectordb/opensearch.py

@@ -146,7 +146,7 @@ class OpenSearchDB(BaseVectorDB):
 
     def query(
         self,
-        input_query: list[str],
+        input_query: str,
         n_results: int,
         where: dict[str, any],
         citations: bool = False,
@@ -155,8 +155,8 @@ class OpenSearchDB(BaseVectorDB):
         """
         query contents from vector database based on vector similarity
 
-        :param input_query: list of query string
-        :type input_query: list[str]
+        :param input_query: query string
+        :type input_query: str
         :param n_results: no of similar documents to fetch from database
         :type n_results: int
         :param where: Optional. to filter data

+ 2 - 2
embedchain/vectordb/pinecone.py

@@ -150,7 +150,7 @@ class PineconeDB(BaseVectorDB):
 
     def query(
         self,
-        input_query: list[str],
+        input_query: str,
         n_results: int,
         where: Optional[dict[str, any]] = None,
         raw_filter: Optional[dict[str, any]] = None,
@@ -162,7 +162,7 @@ class PineconeDB(BaseVectorDB):
         Query contents from vector database based on vector similarity.
 
         Args:
-            input_query (list[str]): List of query strings.
+            input_query (str): query string.
             n_results (int): Number of similar documents to fetch from the database.
             where (dict[str, any], optional): Filter criteria for the search.
             raw_filter (dict[str, any], optional): Advanced raw filter criteria for the search.

+ 3 - 3
embedchain/vectordb/qdrant.py

@@ -161,7 +161,7 @@ class QdrantDB(BaseVectorDB):
 
     def query(
         self,
-        input_query: list[str],
+        input_query: str,
         n_results: int,
         where: dict[str, any],
         citations: bool = False,
@@ -169,8 +169,8 @@ class QdrantDB(BaseVectorDB):
     ) -> Union[list[tuple[str, dict]], list[str]]:
         """
         query contents from vector database based on vector similarity
-        :param input_query: list of query string
-        :type input_query: list[str]
+        :param input_query: query string
+        :type input_query: str
         :param n_results: no of similar documents to fetch from database
         :type n_results: int
         :param where: Optional. to filter data

+ 3 - 3
embedchain/vectordb/weaviate.py

@@ -219,12 +219,12 @@ class WeaviateDB(BaseVectorDB):
                 )
 
     def query(
-        self, input_query: list[str], n_results: int, where: dict[str, any], citations: bool = False
+        self, input_query: str, n_results: int, where: dict[str, any], citations: bool = False
     ) -> Union[list[tuple[str, dict]], list[str]]:
         """
         query contents from vector database based on vector similarity
-        :param input_query: list of query string
-        :type input_query: list[str]
+        :param input_query: query string
+        :type input_query: str
         :param n_results: no of similar documents to fetch from database
         :type n_results: int
         :param where: Optional. to filter data

+ 3 - 3
embedchain/vectordb/zilliz.py

@@ -138,7 +138,7 @@ class ZillizVectorDB(BaseVectorDB):
 
     def query(
         self,
-        input_query: list[str],
+        input_query: str,
         n_results: int,
         where: dict[str, Any],
         citations: bool = False,
@@ -147,8 +147,8 @@ class ZillizVectorDB(BaseVectorDB):
         """
         Query contents from vector database based on vector similarity
 
-        :param input_query: list of query string
-        :type input_query: list[str]
+        :param input_query: query string
+        :type input_query: str
         :param n_results: no of similar documents to fetch from database
         :type n_results: int
         :param where: to filter data

+ 1 - 1
tests/vectordb/test_elasticsearch_db.py

@@ -58,7 +58,7 @@ class TestEsDB(unittest.TestCase):
         mock_client.return_value.search.return_value = search_response
 
         # Query the database for the documents that are most similar to the query "This is a document".
-        query = ["This is a document"]
+        query = "This is a document"
         results_without_citations = self.db.query(query, n_results=2, where={})
         expected_results_without_citations = ["This is a document.", "This is another document."]
         self.assertEqual(results_without_citations, expected_results_without_citations)

+ 1 - 1
tests/vectordb/test_qdrant.py

@@ -114,7 +114,7 @@ class TestQdrantDB(unittest.TestCase):
         App(config=app_config, db=db, embedding_model=embedder)
 
         # Query for the document.
-        db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"})
+        db.query(input_query="This is a test document.", n_results=1, where={"doc_id": "123"})
 
         qdrant_client_mock.return_value.search.assert_called_once_with(
             collection_name="embedchain-store-1536",

+ 2 - 2
tests/vectordb/test_weaviate.py

@@ -161,7 +161,7 @@ class TestWeaviateDb(unittest.TestCase):
         App(config=app_config, db=db, embedding_model=embedder)
 
         # Query for the document.
-        db.query(input_query=["This is a test document."], n_results=1, where={})
+        db.query(input_query="This is a test document.", n_results=1, where={})
 
         weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1536", ["text"])
         weaviate_client_query_get_mock.with_near_vector.assert_called_once_with({"vector": [1, 2, 3]})
@@ -185,7 +185,7 @@ class TestWeaviateDb(unittest.TestCase):
         App(config=app_config, db=db, embedding_model=embedder)
 
         # Query for the document.
-        db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"})
+        db.query(input_query="This is a test document.", n_results=1, where={"doc_id": "123"})
 
         weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1536", ["text"])
         weaviate_client_query_get_mock.with_where.assert_called_once_with(

+ 2 - 2
tests/vectordb/test_zilliz_db.py

@@ -139,7 +139,7 @@ class TestZillizDBCollection:
                 ]
             ]
 
-            query_result = zilliz_db.query(input_query=["query_text"], n_results=1, where={})
+            query_result = zilliz_db.query(input_query="query_text", n_results=1, where={})
 
             # Assert that MilvusClient.search was called with the correct parameters
             mock_search.assert_called_with(
@@ -154,7 +154,7 @@ class TestZillizDBCollection:
             assert query_result == ["result_doc"]
 
             query_result_with_citations = zilliz_db.query(
-                input_query=["query_text"], n_results=1, where={}, citations=True
+                input_query="query_text", n_results=1, where={}, citations=True
             )
 
             mock_search.assert_called_with(