Просмотр исходного кода

[Feature] Add support for weaviate vector db (#782)

Rupesh Bansal 1 год назад
Родитель
Сommit
cdfd6519c8

+ 4 - 0
configs/weaviate.yaml

@@ -0,0 +1,4 @@
+vectordb:
+  provider: weaviate
+  config:
+    collection_name: my_weaviate_index

+ 16 - 1
docs/components/vector-databases.mdx

@@ -187,6 +187,21 @@ _Coming soon_
 
 ## Weaviate
 
-_Coming soon_
+In order to use Weaviate as a vector database, set the environment variables `WEAVIATE_ENDPOINT` and `WEAVIATE_API_KEY` which you can find on [Weaviate dashboard](https://console.weaviate.cloud/dashboard).
+
+```python main.py
+from embedchain import App
+
+# load weaviate configuration from yaml file
+app = App.from_config(yaml_path="config.yaml")
+```
+
+```yaml config.yaml
+vectordb:
+  provider: weaviate
+  config:
+    collection_name: my_weaviate_index
+```
+
 
 <Snippet file="missing-vector-db-tip.mdx" />

+ 16 - 0
embedchain/config/vectordb/weaviate.py

@@ -0,0 +1,16 @@
+from typing import Dict, Optional
+
+from embedchain.config.vectordb.base import BaseVectorDbConfig
+from embedchain.helper.json_serializable import register_deserializable
+
+
+@register_deserializable
+class WeaviateDBConfig(BaseVectorDbConfig):
+    def __init__(
+        self,
+        collection_name: Optional[str] = None,
+        dir: Optional[str] = None,
+        **extra_params: Dict[str, any],
+    ):
+        self.extra_params = extra_params
+        super().__init__(collection_name=collection_name, dir=dir)

+ 0 - 3
embedchain/embedchain.py

@@ -359,7 +359,6 @@ class EmbedChain(JSONSerializable):
 
         db_result = self.db.get(ids=ids, where=where)  # optional filter
         existing_ids = set(db_result["ids"])
-
         if len(existing_ids):
             data_dict = {id: (doc, meta) for id, doc, meta in zip(ids, documents, metadatas)}
             data_dict = {id: value for id, value in data_dict.items() if id not in existing_ids}
@@ -436,7 +435,6 @@ class EmbedChain(JSONSerializable):
         :rtype: List[str]
         """
         query_config = config or self.llm.config
-
         if where is not None:
             where = where
         elif query_config is not None and query_config.where is not None:
@@ -463,7 +461,6 @@ class EmbedChain(JSONSerializable):
             where=where,
             skip_embedding=(hasattr(config, "query_type") and config.query_type == "Images"),
         )
-
         return contents
 
     def query(self, input_query: str, config: BaseLlmConfig = None, dry_run=False, where: Optional[Dict] = None) -> str:

+ 2 - 0
embedchain/factory.py

@@ -72,12 +72,14 @@ class VectorDBFactory:
         "elasticsearch": "embedchain.vectordb.elasticsearch.ElasticsearchDB",
         "opensearch": "embedchain.vectordb.opensearch.OpenSearchDB",
         "pinecone": "embedchain.vectordb.pinecone.PineconeDB",
+        "weaviate": "embedchain.vectordb.weaviate.WeaviateDB",
     }
     provider_to_config_class = {
         "chroma": "embedchain.config.vectordb.chroma.ChromaDbConfig",
         "elasticsearch": "embedchain.config.vectordb.elasticsearch.ElasticsearchDBConfig",
         "opensearch": "embedchain.config.vectordb.opensearch.OpenSearchDBConfig",
         "pinecone": "embedchain.config.vectordb.pinecone.PineconeDBConfig",
+        "weaviate": "embedchain.config.vectordb.weaviate.WeaviateDBConfig",
     }
 
     @classmethod

+ 0 - 1
embedchain/llm/base.py

@@ -206,7 +206,6 @@ class BaseLlm(JSONSerializable):
                 k["web_search_result"] = self.access_search_and_get_results(input_query)
             prompt = self.generate_prompt(input_query, contexts, **k)
             logging.info(f"Prompt: {prompt}")
-
             if dry_run:
                 return prompt
 

+ 297 - 0
embedchain/vectordb/weaviate.py

@@ -0,0 +1,297 @@
+import copy
+import os
+from typing import Dict, List, Optional
+
+try:
+    import weaviate
+except ImportError:
+    raise ImportError(
+        "Weaviate requires extra dependencies. Install with `pip install --upgrade 'embedchain[weaviate]'`"
+    ) from None
+
+from embedchain.config.vectordb.weaviate import WeaviateDBConfig
+from embedchain.helper.json_serializable import register_deserializable
+from embedchain.vectordb.base import BaseVectorDB
+
+
+@register_deserializable
+class WeaviateDB(BaseVectorDB):
+    """
+    Weaviate as vector database
+    """
+
+    BATCH_SIZE = 100
+
+    def __init__(
+        self,
+        config: Optional[WeaviateDBConfig] = None,
+    ):
+        """Weaviate as vector database.
+        :param config: Weaviate database config, defaults to None
+        :type config: WeaviateDBConfig, optional
+        :raises ValueError: No config provided
+        """
+        if config is None:
+            self.config = WeaviateDBConfig()
+        else:
+            if not isinstance(config, WeaviateDBConfig):
+                raise TypeError(
+                    "config is not a `WeaviateDBConfig` instance. "
+                    "Please make sure the type is right and that you are passing an instance."
+                )
+            self.config = config
+        self.client = weaviate.Client(
+            url=os.environ.get("WEAVIATE_ENDPOINT"),
+            auth_client_secret=weaviate.AuthApiKey(api_key=os.environ.get("WEAVIATE_API_KEY")),
+            **self.config.extra_params,
+        )
+
+        # Call parent init here because embedder is needed
+        super().__init__(config=self.config)
+
+    def _initialize(self):
+        """
+        This method is needed because `embedder` attribute needs to be set externally before it can be initialized.
+        """
+
+        if not self.embedder:
+            raise ValueError("Embedder not set. Please set an embedder with `set_embedder` before initialization.")
+
+        self.index_name = self._get_index_name()
+        self.metadata_keys = {"data_type", "doc_id", "url", "hash", "app_id", "text"}
+        if not self.client.schema.exists(self.index_name):
+            # id is a reserved field in Weaviate, hence we had to change the name of the id field to identifier
+            # The none vectorizer is crucial as we have our own custom embedding function
+            class_obj = {
+                "classes": [
+                    {
+                        "class": self.index_name,
+                        "vectorizer": "none",
+                        "properties": [
+                            {
+                                "name": "identifier",
+                                "dataType": ["text"],
+                            },
+                            {
+                                "name": "text",
+                                "dataType": ["text"],
+                            },
+                            {
+                                "name": "metadata",
+                                "dataType": [self.index_name + "_metadata"],
+                            },
+                        ],
+                    },
+                    {
+                        "class": self.index_name + "_metadata",
+                        "vectorizer": "none",
+                        "properties": [
+                            {
+                                "name": "data_type",
+                                "dataType": ["text"],
+                            },
+                            {
+                                "name": "doc_id",
+                                "dataType": ["text"],
+                            },
+                            {
+                                "name": "url",
+                                "dataType": ["text"],
+                            },
+                            {
+                                "name": "hash",
+                                "dataType": ["text"],
+                            },
+                            {
+                                "name": "app_id",
+                                "dataType": ["text"],
+                            },
+                            {
+                                "name": "text",
+                                "dataType": ["text"],
+                            },
+                        ],
+                    },
+                ]
+            }
+
+            self.client.schema.create(class_obj)
+
+    def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
+        """
+        Get existing doc ids present in vector database
+        :param ids: _list of doc ids to check for existance
+        :type ids: List[str]
+        :param where: to filter data
+        :type where: Dict[str, any]
+        :return: ids
+        :rtype: Set[str]
+        """
+
+        if ids is None or len(ids) == 0:
+            return {"ids": []}
+
+        existing_ids = []
+        cursor = None
+        has_iterated_once = False
+        while cursor is not None or not has_iterated_once:
+            has_iterated_once = True
+            results = self._query_with_cursor(
+                self.client.query.get(self.index_name, ["identifier"])
+                .with_additional(["id"])
+                .with_limit(self.BATCH_SIZE),
+                cursor,
+            )
+            fetched_results = results["data"]["Get"].get(self.index_name, [])
+            if len(fetched_results) == 0:
+                break
+            for result in fetched_results:
+                existing_ids.append(result["identifier"])
+                cursor = result["_additional"]["id"]
+
+        return {"ids": existing_ids}
+
+    def add(
+        self,
+        embeddings: List[List[float]],
+        documents: List[str],
+        metadatas: List[object],
+        ids: List[str],
+        skip_embedding: bool,
+    ):
+        """add data in vector database
+        :param embeddings: list of embeddings for the corresponding documents to be added
+        :type documents: List[List[float]]
+        :param documents: list of texts to add
+        :type documents: List[str]
+        :param metadatas: list of metadata associated with docs
+        :type metadatas: List[object]
+        :param ids: ids of docs
+        :type ids: List[str]
+        :param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
+        generated or not
+        :type skip_embedding: bool
+        """
+
+        print("Adding documents to Weaviate...")
+        if not skip_embedding:
+            embeddings = self.embedder.embedding_fn(documents)
+        self.client.batch.configure(batch_size=self.BATCH_SIZE, timeout_retries=3)  # Configure batch
+        with self.client.batch as batch:  # Initialize a batch process
+            for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
+                doc = {"identifier": id, "text": text}
+                updated_metadata = {"text": text}
+                if metadata is not None:
+                    updated_metadata.update(**metadata)
+
+                obj_uuid = batch.add_data_object(
+                    data_object=copy.deepcopy(doc), class_name=self.index_name, vector=embedding
+                )
+                metadata_uuid = batch.add_data_object(
+                    data_object=copy.deepcopy(updated_metadata),
+                    class_name=self.index_name + "_metadata",
+                    vector=embedding,
+                )
+                batch.add_reference(obj_uuid, self.index_name, "metadata", metadata_uuid, self.index_name + "_metadata")
+
+    def query(self, input_query: List[str], n_results: int, where: Dict[str, any], skip_embedding: bool) -> List[str]:
+        """
+        query contents from vector database based on vector similarity
+        :param input_query: list of query string
+        :type input_query: List[str]
+        :param n_results: no of similar documents to fetch from database
+        :type n_results: int
+        :param where: Optional. to filter data
+        :type where: Dict[str, any]
+        :param skip_embedding: A boolean flag indicating if the embedding for the documents to be added is to be
+        generated or not
+        :type skip_embedding: bool
+        :return: Database contents that are the result of the query
+        :rtype: List[str]
+        """
+        if not skip_embedding:
+            query_vector = self.embedder.embedding_fn([input_query])[0]
+        else:
+            query_vector = input_query
+        keys = set(where.keys() if where is not None else set())
+        if len(keys.intersection(self.metadata_keys)) != 0:
+            weaviate_where_operands = []
+            for key in keys:
+                if key in self.metadata_keys:
+                    weaviate_where_operands.append(
+                        {
+                            "path": ["metadata", self.index_name + "_metadata", key],
+                            "operator": "Equal",
+                            "valueText": where.get(key),
+                        }
+                    )
+            if len(weaviate_where_operands) == 1:
+                weaviate_where_clause = weaviate_where_operands[0]
+            else:
+                weaviate_where_clause = {"operator": "And", "operands": weaviate_where_operands}
+
+            results = (
+                self.client.query.get(self.index_name, ["text"])
+                .with_where(weaviate_where_clause)
+                .with_near_vector({"vector": query_vector})
+                .with_limit(n_results)
+                .do()
+            )
+        else:
+            results = (
+                self.client.query.get(self.index_name, ["text"])
+                .with_near_vector({"vector": query_vector})
+                .with_limit(n_results)
+                .do()
+            )
+        matched_tokens = []
+        for result in results["data"]["Get"].get(self.index_name):
+            matched_tokens.append(result["text"])
+
+        return matched_tokens
+
+    def set_collection_name(self, name: str):
+        """
+        Set the name of the collection. A collection is an isolated space for vectors.
+        :param name: Name of the collection.
+        :type name: str
+        """
+        if not isinstance(name, str):
+            raise TypeError("Collection name must be a string")
+        self.config.collection_name = name
+
+    def count(self) -> int:
+        """
+        Count number of documents/chunks embedded in the database.
+        :return: number of documents
+        :rtype: int
+        """
+        data = self.client.query.aggregate(self.index_name).with_meta_count().do()
+        return data["data"]["Aggregate"].get(self.index_name)[0]["meta"]["count"]
+
+    def _get_or_create_db(self):
+        """Called during initialization"""
+        return self.client
+
+    def reset(self):
+        """
+        Resets the database. Deletes all embeddings irreversibly.
+        """
+        # Delete all data from the database
+        self.client.batch.delete_objects(
+            self.index_name, where={"path": ["identifier"], "operator": "Like", "valueText": ".*"}
+        )
+
+    # Weaviate internally by default capitalizes the class name
+    def _get_index_name(self) -> str:
+        """Get the Weaviate index for a collection
+        :return: Weaviate index
+        :rtype: str
+        """
+        return f"{self.config.collection_name}_{self.embedder.vector_dimension}".capitalize()
+
+    def _query_with_cursor(self, query, cursor):
+        if cursor is not None:
+            query.with_after(cursor)
+        results = query.do()
+        return results

+ 2 - 0
pyproject.toml

@@ -111,6 +111,7 @@ fastapi-poe = { version = "0.0.16", optional = true }
 discord = { version = "^2.3.2", optional = true }
 slack-sdk = { version = "3.21.3", optional = true }
 cohere = { version = "^4.27", optional= true }
+weaviate-client = { version = "^3.24.1", optional= true }
 docx2txt = { version="^0.8", optional=true }
 pinecone-client = { version = "^2.2.4", optional = true }
 unstructured = {extras = ["local-inference"], version = "^0.10.18", optional=true}
@@ -145,6 +146,7 @@ poe = ["fastapi-poe"]
 discord = ["discord"]
 slack = ["slack-sdk", "flask"]
 whatsapp = ["twilio", "flask"]
+weaviate = ["weaviate-client"]
 pinecone = ["pinecone-client"]
 images = ["torch", "ftfy", "regex", "pillow", "torchvision"]
 huggingface_hub=["huggingface_hub"]

+ 244 - 0
tests/vectordb/test_weaviate.py

@@ -0,0 +1,244 @@
+import unittest
+from unittest.mock import patch
+
+from embedchain import App
+from embedchain.config import AppConfig
+from embedchain.config.vectordb.pinecone import PineconeDBConfig
+from embedchain.embedder.base import BaseEmbedder
+from embedchain.vectordb.weaviate import WeaviateDB
+
+
+class TestWeaviateDb(unittest.TestCase):
+    def test_incorrect_config_throws_error(self):
+        """Test the init method of the WeaviateDb class throws error for incorrect config"""
+        with self.assertRaises(TypeError):
+            WeaviateDB(config=PineconeDBConfig())
+
+    @patch("embedchain.vectordb.weaviate.weaviate")
+    def test_initialize(self, weaviate_mock):
+        """Test the init method of the WeaviateDb class."""
+        weaviate_client_mock = weaviate_mock.Client.return_value
+        weaviate_client_schema_mock = weaviate_client_mock.schema
+
+        # Mock that schema doesn't already exist so that a new schema is created
+        weaviate_client_schema_mock.exists.return_value = False
+        # Set the embedder
+        embedder = BaseEmbedder()
+        embedder.set_vector_dimension(1526)
+
+        # Create a Weaviate instance
+        db = WeaviateDB()
+        app_config = AppConfig(collect_metrics=False)
+        App(config=app_config, db=db, embedder=embedder)
+
+        expected_class_obj = {
+            "classes": [
+                {
+                    "class": "Embedchain_store_1526",
+                    "vectorizer": "none",
+                    "properties": [
+                        {
+                            "name": "identifier",
+                            "dataType": ["text"],
+                        },
+                        {
+                            "name": "text",
+                            "dataType": ["text"],
+                        },
+                        {
+                            "name": "metadata",
+                            "dataType": ["Embedchain_store_1526_metadata"],
+                        },
+                    ],
+                },
+                {
+                    "class": "Embedchain_store_1526_metadata",
+                    "vectorizer": "none",
+                    "properties": [
+                        {
+                            "name": "data_type",
+                            "dataType": ["text"],
+                        },
+                        {
+                            "name": "doc_id",
+                            "dataType": ["text"],
+                        },
+                        {
+                            "name": "url",
+                            "dataType": ["text"],
+                        },
+                        {
+                            "name": "hash",
+                            "dataType": ["text"],
+                        },
+                        {
+                            "name": "app_id",
+                            "dataType": ["text"],
+                        },
+                        {
+                            "name": "text",
+                            "dataType": ["text"],
+                        },
+                    ],
+                },
+            ]
+        }
+
+        # Assert that the Weaviate client was initialized
+        weaviate_mock.Client.assert_called_once()
+        self.assertEqual(db.index_name, "Embedchain_store_1526")
+        weaviate_client_schema_mock.create.assert_called_once_with(expected_class_obj)
+
+    @patch("embedchain.vectordb.weaviate.weaviate")
+    def test_get_or_create_db(self, weaviate_mock):
+        """Test the _get_or_create_db method of the WeaviateDb class."""
+        weaviate_client_mock = weaviate_mock.Client.return_value
+
+        embedder = BaseEmbedder()
+        embedder.set_vector_dimension(1526)
+
+        # Create a Weaviate instance
+        db = WeaviateDB()
+        app_config = AppConfig(collect_metrics=False)
+        App(config=app_config, db=db, embedder=embedder)
+
+        expected_client = db._get_or_create_db()
+        self.assertEqual(expected_client, weaviate_client_mock)
+
+    @patch("embedchain.vectordb.weaviate.weaviate")
+    def test_add(self, weaviate_mock):
+        """Test the add method of the WeaviateDb class."""
+        weaviate_client_mock = weaviate_mock.Client.return_value
+        weaviate_client_batch_mock = weaviate_client_mock.batch
+        weaviate_client_batch_enter_mock = weaviate_client_mock.batch.__enter__.return_value
+
+        # Set the embedder
+        embedder = BaseEmbedder()
+        embedder.set_vector_dimension(1526)
+
+        # Create a Weaviate instance
+        db = WeaviateDB()
+        app_config = AppConfig(collect_metrics=False)
+        App(config=app_config, db=db, embedder=embedder)
+        db.BATCH_SIZE = 1
+
+        embeddings = [[1, 2, 3], [4, 5, 6]]
+        documents = ["This is a test document.", "This is another test document."]
+        metadatas = [None, None]
+        ids = ["123", "456"]
+        skip_embedding = True
+        db.add(embeddings, documents, metadatas, ids, skip_embedding)
+
+        # Check if the document was added to the database.
+        weaviate_client_batch_mock.configure.assert_called_once_with(batch_size=1, timeout_retries=3)
+        weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
+            data_object={"text": documents[0]}, class_name="Embedchain_store_1526_metadata", vector=embeddings[0]
+        )
+        weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
+            data_object={"text": documents[1]}, class_name="Embedchain_store_1526_metadata", vector=embeddings[1]
+        )
+
+        weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
+            data_object={"identifier": ids[0], "text": documents[0]},
+            class_name="Embedchain_store_1526",
+            vector=embeddings[0],
+        )
+        weaviate_client_batch_enter_mock.add_data_object.assert_any_call(
+            data_object={"identifier": ids[1], "text": documents[1]},
+            class_name="Embedchain_store_1526",
+            vector=embeddings[1],
+        )
+
+    @patch("embedchain.vectordb.weaviate.weaviate")
+    def test_query_without_where(self, weaviate_mock):
+        """Test the query method of the WeaviateDb class."""
+        weaviate_client_mock = weaviate_mock.Client.return_value
+        weaviate_client_query_mock = weaviate_client_mock.query
+        weaviate_client_query_get_mock = weaviate_client_query_mock.get.return_value
+
+        # Set the embedder
+        embedder = BaseEmbedder()
+        embedder.set_vector_dimension(1526)
+
+        # Create a Weaviate instance
+        db = WeaviateDB()
+        app_config = AppConfig(collect_metrics=False)
+        App(config=app_config, db=db, embedder=embedder)
+
+        # Query for the document.
+        db.query(input_query=["This is a test document."], n_results=1, where={}, skip_embedding=True)
+
+        weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1526", ["text"])
+        weaviate_client_query_get_mock.with_near_vector.assert_called_once_with(
+            {"vector": ["This is a test document."]}
+        )
+
+    @patch("embedchain.vectordb.weaviate.weaviate")
+    def test_query_with_where(self, weaviate_mock):
+        """Test the query method of the WeaviateDb class."""
+        weaviate_client_mock = weaviate_mock.Client.return_value
+        weaviate_client_query_mock = weaviate_client_mock.query
+        weaviate_client_query_get_mock = weaviate_client_query_mock.get.return_value
+        weaviate_client_query_get_where_mock = weaviate_client_query_get_mock.with_where.return_value
+
+        # Set the embedder
+        embedder = BaseEmbedder()
+        embedder.set_vector_dimension(1526)
+
+        # Create a Weaviate instance
+        db = WeaviateDB()
+        app_config = AppConfig(collect_metrics=False)
+        App(config=app_config, db=db, embedder=embedder)
+
+        # Query for the document.
+        db.query(input_query=["This is a test document."], n_results=1, where={"doc_id": "123"}, skip_embedding=True)
+
+        weaviate_client_query_mock.get.assert_called_once_with("Embedchain_store_1526", ["text"])
+        weaviate_client_query_get_mock.with_where.assert_called_once_with(
+            {"operator": "Equal", "path": ["metadata", "Embedchain_store_1526_metadata", "doc_id"], "valueText": "123"}
+        )
+        weaviate_client_query_get_where_mock.with_near_vector.assert_called_once_with(
+            {"vector": ["This is a test document."]}
+        )
+
+    @patch("embedchain.vectordb.weaviate.weaviate")
+    def test_reset(self, weaviate_mock):
+        """Test the reset method of the WeaviateDb class."""
+        weaviate_client_mock = weaviate_mock.Client.return_value
+        weaviate_client_batch_mock = weaviate_client_mock.batch
+
+        # Set the embedder
+        embedder = BaseEmbedder()
+        embedder.set_vector_dimension(1526)
+
+        # Create a Weaviate instance
+        db = WeaviateDB()
+        app_config = AppConfig(collect_metrics=False)
+        App(config=app_config, db=db, embedder=embedder)
+
+        # Reset the database.
+        db.reset()
+
+        weaviate_client_batch_mock.delete_objects.assert_called_once_with(
+            "Embedchain_store_1526", where={"path": ["identifier"], "operator": "Like", "valueText": ".*"}
+        )
+
+    @patch("embedchain.vectordb.weaviate.weaviate")
+    def test_count(self, weaviate_mock):
+        """Test the reset method of the WeaviateDb class."""
+        weaviate_client_mock = weaviate_mock.Client.return_value
+        weaviate_client_query = weaviate_client_mock.query
+
+        # Set the embedder
+        embedder = BaseEmbedder()
+        embedder.set_vector_dimension(1526)
+
+        # Create a Weaviate instance
+        db = WeaviateDB()
+        app_config = AppConfig(collect_metrics=False)
+        App(config=app_config, db=db, embedder=embedder)
+
+        # Reset the database.
+        db.count()
+
+        weaviate_client_query.aggregate.assert_called_once_with("Embedchain_store_1526")