Pārlūkot izejas kodu

[Improvements] Package improvements (#993)

Co-authored-by: Deven Patel <deven298@yahoo.com>
Deven Patel 1 gadu atpakaļ
vecāks
revīzija
51b4966801

+ 0 - 1
embedchain/chunkers/base_chunker.py

@@ -41,7 +41,6 @@ class BaseChunker(JSONSerializable):
             url = meta_data["url"]
 
             chunks = self.get_chunks(content)
-
             for chunk in chunks:
                 chunk_id = hashlib.sha256((chunk + url).encode()).hexdigest()
                 chunk_id = f"{app_id}--{chunk_id}" if app_id is not None else chunk_id

+ 16 - 14
embedchain/data_formatter/data_formatter.py

@@ -1,5 +1,5 @@
 from importlib import import_module
-from typing import Any, Dict
+from typing import Optional
 
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config import AddConfig
@@ -16,7 +16,13 @@ class DataFormatter(JSONSerializable):
     .add or .add_local method call
     """
 
-    def __init__(self, data_type: DataType, config: AddConfig, kwargs: Dict[str, Any]):
+    def __init__(
+        self,
+        data_type: DataType,
+        config: AddConfig,
+        loader: Optional[BaseLoader] = None,
+        chunker: Optional[BaseChunker] = None,
+    ):
         """
         Initialize a dataformatter, set data type and chunker based on datatype.
 
@@ -25,15 +31,15 @@ class DataFormatter(JSONSerializable):
         :param config: AddConfig instance with nested loader and chunker config attributes.
         :type config: AddConfig
         """
-        self.loader = self._get_loader(data_type=data_type, config=config.loader, kwargs=kwargs)
-        self.chunker = self._get_chunker(data_type=data_type, config=config.chunker, kwargs=kwargs)
+        self.loader = self._get_loader(data_type=data_type, config=config.loader, loader=loader)
+        self.chunker = self._get_chunker(data_type=data_type, config=config.chunker, chunker=chunker)
 
     def _lazy_load(self, module_path: str):
         module_path, class_name = module_path.rsplit(".", 1)
         module = import_module(module_path)
         return getattr(module, class_name)
 
-    def _get_loader(self, data_type: DataType, config: LoaderConfig, kwargs: Dict[str, Any]) -> BaseLoader:
+    def _get_loader(self, data_type: DataType, config: LoaderConfig, loader: Optional[BaseLoader]) -> BaseLoader:
         """
         Returns the appropriate data loader for the given data type.
 
@@ -68,8 +74,8 @@ class DataFormatter(JSONSerializable):
             DataType.DISCORD: "embedchain.loaders.discord.DiscordLoader",
         }
 
-        if data_type == DataType.CUSTOM or ("loader" in kwargs):
-            loader_class: type = kwargs.get("loader", None)
+        if data_type == DataType.CUSTOM or loader is not None:
+            loader_class: type = loader
             if loader_class:
                 return loader_class
         elif data_type in loaders:
@@ -82,7 +88,7 @@ class DataFormatter(JSONSerializable):
                         check `https://docs.embedchain.ai/data-sources/overview`."
         )
 
-    def _get_chunker(self, data_type: DataType, config: ChunkerConfig, kwargs: Dict[str, Any]) -> BaseChunker:
+    def _get_chunker(self, data_type: DataType, config: ChunkerConfig, chunker: Optional[BaseChunker]) -> BaseChunker:
         """Returns the appropriate chunker for the given data type (updated for lazy loading)."""
         chunker_classes = {
             DataType.YOUTUBE_VIDEO: "embedchain.chunkers.youtube_video.YoutubeVideoChunker",
@@ -108,12 +114,8 @@ class DataFormatter(JSONSerializable):
             DataType.CUSTOM: "embedchain.chunkers.common_chunker.CommonChunker",
         }
 
-        if "chunker" in kwargs:
-            chunker_class = kwargs.get("chunker", None)
-            if chunker_class:
-                chunker = chunker_class(config)
-                chunker.set_data_type(data_type)
-                return chunker
+        if chunker is not None:
+            return chunker
         elif data_type in chunker_classes:
             chunker_class = self._lazy_load(chunker_classes[data_type])
             chunker = chunker_class(config)

+ 28 - 10
embedchain/embedchain.py

@@ -133,7 +133,9 @@ class EmbedChain(JSONSerializable):
         metadata: Optional[Dict[str, Any]] = None,
         config: Optional[AddConfig] = None,
         dry_run=False,
-        **kwargs: Dict[str, Any],
+        loader: Optional[BaseLoader] = None,
+        chunker: Optional[BaseChunker] = None,
+        **kwargs: Optional[Dict[str, Any]],
     ):
         """
         Adds the data from the given URL to the vector db.
@@ -192,9 +194,9 @@ class EmbedChain(JSONSerializable):
 
         self.user_asks.append([source, data_type.value, metadata])
 
-        data_formatter = DataFormatter(data_type, config, kwargs)
+        data_formatter = DataFormatter(data_type, config, loader, chunker)
         documents, metadatas, _ids, new_chunks = self._load_and_embed(
-            data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, dry_run
+            data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, dry_run, **kwargs
         )
         if data_type in {DataType.DOCS_SITE}:
             self.is_docs_site_instance = True
@@ -238,7 +240,7 @@ class EmbedChain(JSONSerializable):
         data_type: Optional[DataType] = None,
         metadata: Optional[Dict[str, Any]] = None,
         config: Optional[AddConfig] = None,
-        **kwargs: Dict[str, Any],
+        **kwargs: Optional[Dict[str, Any]],
     ):
         """
         Adds the data from the given URL to the vector db.
@@ -269,7 +271,7 @@ class EmbedChain(JSONSerializable):
             data_type=data_type,
             metadata=metadata,
             config=config,
-            kwargs=kwargs,
+            **kwargs,
         )
 
     def _get_existing_doc_id(self, chunker: BaseChunker, src: Any):
@@ -338,6 +340,7 @@ class EmbedChain(JSONSerializable):
         metadata: Optional[Dict[str, Any]] = None,
         source_hash: Optional[str] = None,
         dry_run=False,
+        **kwargs: Optional[Dict[str, Any]],
     ):
         """
         Loads the data from the given URL, chunks it, and adds it to database.
@@ -431,6 +434,7 @@ class EmbedChain(JSONSerializable):
             metadatas=metadatas,
             ids=ids,
             skip_embedding=(chunker.data_type == DataType.IMAGES),
+            **kwargs,
         )
         count_new_chunks = self.db.count() - chunks_before_addition
 
@@ -448,7 +452,12 @@ class EmbedChain(JSONSerializable):
         ]
 
     def _retrieve_from_database(
-        self, input_query: str, config: Optional[BaseLlmConfig] = None, where=None, citations: bool = False
+        self,
+        input_query: str,
+        config: Optional[BaseLlmConfig] = None,
+        where=None,
+        citations: bool = False,
+        **kwargs: Optional[Dict[str, Any]],
     ) -> Union[List[Tuple[str, str, str]], List[str]]:
         """
         Queries the vector database based on the given input query.
@@ -492,6 +501,7 @@ class EmbedChain(JSONSerializable):
             where=where,
             skip_embedding=(hasattr(config, "query_type") and config.query_type == "Images"),
             citations=citations,
+            **kwargs,
         )
 
         return contexts
@@ -526,9 +536,13 @@ class EmbedChain(JSONSerializable):
         or the dry run result
         :rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]]
         """
-        citations = kwargs.get("citations", False)
+        if "citations" in kwargs:
+            citations = kwargs.pop("citations")
+        else:
+            citations = False
+
         contexts = self._retrieve_from_database(
-            input_query=input_query, config=config, where=where, citations=citations
+            input_query=input_query, config=config, where=where, citations=citations, **kwargs
         )
         if citations and len(contexts) > 0 and isinstance(contexts[0], tuple):
             contexts_data_for_llm_query = list(map(lambda x: x[0], contexts))
@@ -579,9 +593,13 @@ class EmbedChain(JSONSerializable):
         or the dry run result
         :rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]]
         """
-        citations = kwargs.get("citations", False)
+        if "citations" in kwargs:
+            citations = kwargs.pop("citations")
+        else:
+            citations = False
+
         contexts = self._retrieve_from_database(
-            input_query=input_query, config=config, where=where, citations=citations
+            input_query=input_query, config=config, where=where, citations=citations, **kwargs
         )
         if citations and len(contexts) > 0 and isinstance(contexts[0], tuple):
             contexts_data_for_llm_query = list(map(lambda x: x[0], contexts))

+ 0 - 1
embedchain/loaders/github.py

@@ -196,7 +196,6 @@ class GithubLoader(BaseLoader):
         logging.info(f"Total repos found: {repos_results.totalCount}")
         for repo_result in tqdm(repos_results, total=repos_results.totalCount, desc="Loading discussions from github"):
             teams = repo_result.get_teams()
-            # import pdb; pdb.set_trace()
             for team in teams:
                 team_discussions = team.get_discussions()
                 for discussion in team_discussions:

+ 15 - 0
embedchain/utils.py

@@ -1,3 +1,4 @@
+import itertools
 import json
 import logging
 import os
@@ -6,6 +7,7 @@ import string
 from typing import Any
 
 from schema import Optional, Or, Schema
+from tqdm import tqdm
 
 from embedchain.models.data_type import DataType
 
@@ -422,3 +424,16 @@ def validate_config(config_data):
     )
 
     return schema.validate(config_data)
+
+
+def chunks(iterable, batch_size=100, desc="Processing chunks"):
+    """A helper function to break an iterable into chunks of size batch_size."""
+    it = iter(iterable)
+    total_size = len(iterable)
+
+    with tqdm(total=total_size, desc=desc, unit="batch") as pbar:
+        chunk = tuple(itertools.islice(it, batch_size))
+        while chunk:
+            yield chunk
+            pbar.update(len(chunk))
+            chunk = tuple(itertools.islice(it, batch_size))

+ 4 - 0
embedchain/vectordb/chroma.py

@@ -133,6 +133,7 @@ class ChromaDB(BaseVectorDB):
         metadatas: List[object],
         ids: List[str],
         skip_embedding: bool,
+        **kwargs: Optional[Dict[str, Any]],
     ) -> Any:
         """
         Add vectors to chroma database
@@ -198,6 +199,7 @@ class ChromaDB(BaseVectorDB):
         where: Dict[str, any],
         skip_embedding: bool,
         citations: bool = False,
+        **kwargs: Optional[Dict[str, Any]],
     ) -> Union[List[Tuple[str, str, str]], List[str]]:
         """
         Query contents from vector database based on vector similarity
@@ -225,6 +227,7 @@ class ChromaDB(BaseVectorDB):
                     ],
                     n_results=n_results,
                     where=self._generate_where_clause(where),
+                    **kwargs,
                 )
             else:
                 result = self.collection.query(
@@ -233,6 +236,7 @@ class ChromaDB(BaseVectorDB):
                     ],
                     n_results=n_results,
                     where=self._generate_where_clause(where),
+                    **kwargs,
                 )
         except InvalidDimensionException as e:
             raise InvalidDimensionException(

+ 2 - 0
embedchain/vectordb/elasticsearch.py

@@ -105,6 +105,7 @@ class ElasticsearchDB(BaseVectorDB):
         metadatas: List[object],
         ids: List[str],
         skip_embedding: bool,
+        **kwargs: Optional[Dict[str, any]],
     ) -> Any:
         """
         add data in vector database
@@ -142,6 +143,7 @@ class ElasticsearchDB(BaseVectorDB):
         where: Dict[str, any],
         skip_embedding: bool,
         citations: bool = False,
+        **kwargs: Optional[Dict[str, Any]],
     ) -> Union[List[Tuple[str, str, str]], List[str]]:
         """
         query contents from vector data base based on vector similarity

+ 5 - 2
embedchain/vectordb/opensearch.py

@@ -1,6 +1,6 @@
 import logging
 import time
-from typing import Dict, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, List, Optional, Set, Tuple, Union
 
 from tqdm import tqdm
 
@@ -121,6 +121,7 @@ class OpenSearchDB(BaseVectorDB):
         metadatas: List[object],
         ids: List[str],
         skip_embedding: bool,
+        **kwargs: Optional[Dict[str, any]],
     ):
         """Add data in vector database.
 
@@ -154,7 +155,7 @@ class OpenSearchDB(BaseVectorDB):
             ]
 
             # Perform bulk operation
-            bulk(self.client, batch_entries)
+            bulk(self.client, batch_entries, **kwargs)
             self.client.indices.refresh(index=self._get_index())
 
             # Sleep to avoid rate limiting
@@ -167,6 +168,7 @@ class OpenSearchDB(BaseVectorDB):
         where: Dict[str, any],
         skip_embedding: bool,
         citations: bool = False,
+        **kwargs: Optional[Dict[str, Any]],
     ) -> Union[List[Tuple[str, str, str]], List[str]]:
         """
         query contents from vector data base based on vector similarity
@@ -209,6 +211,7 @@ class OpenSearchDB(BaseVectorDB):
             metadata_field="metadata",
             pre_filter=pre_filter,
             k=n_results,
+            **kwargs,
         )
 
         contexts = []

+ 6 - 4
embedchain/vectordb/pinecone.py

@@ -10,6 +10,7 @@ except ImportError:
 
 from embedchain.config.vectordb.pinecone import PineconeDBConfig
 from embedchain.helpers.json_serializable import register_deserializable
+from embedchain.utils import chunks
 from embedchain.vectordb.base import BaseVectorDB
 
 
@@ -92,6 +93,7 @@ class PineconeDB(BaseVectorDB):
         metadatas: List[object],
         ids: List[str],
         skip_embedding: bool,
+        **kwargs: Optional[Dict[str, any]],
     ):
         """add data in vector database
 
@@ -104,7 +106,6 @@ class PineconeDB(BaseVectorDB):
         """
         docs = []
         print("Adding documents to Pinecone...")
-
         embeddings = self.embedder.embedding_fn(documents)
         for id, text, metadata, embedding in zip(ids, documents, metadatas, embeddings):
             docs.append(
@@ -115,8 +116,8 @@ class PineconeDB(BaseVectorDB):
                 }
             )
 
-        for i in range(0, len(docs), self.BATCH_SIZE):
-            self.client.upsert(docs[i : i + self.BATCH_SIZE])
+        for chunk in chunks(docs, self.BATCH_SIZE, desc="Adding chunks in batches..."):
+            self.client.upsert(chunk, **kwargs)
 
     def query(
         self,
@@ -125,6 +126,7 @@ class PineconeDB(BaseVectorDB):
         where: Dict[str, any],
         skip_embedding: bool,
         citations: bool = False,
+        **kwargs: Optional[Dict[str, any]],
     ) -> Union[List[Tuple[str, str, str]], List[str]]:
         """
         query contents from vector database based on vector similarity
@@ -146,7 +148,7 @@ class PineconeDB(BaseVectorDB):
             query_vector = self.embedder.embedding_fn([input_query])[0]
         else:
             query_vector = input_query
-        data = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True)
+        data = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True, **kwargs)
         contexts = []
         for doc in data["matches"]:
             metadata = doc["metadata"]

+ 5 - 1
embedchain/vectordb/qdrant.py

@@ -1,7 +1,7 @@
 import copy
 import os
 import uuid
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 try:
     from qdrant_client import QdrantClient
@@ -127,6 +127,7 @@ class QdrantDB(BaseVectorDB):
         metadatas: List[object],
         ids: List[str],
         skip_embedding: bool,
+        **kwargs: Optional[Dict[str, any]],
     ):
         """add data in vector database
         :param embeddings: list of embeddings for the corresponding documents to be added
@@ -158,6 +159,7 @@ class QdrantDB(BaseVectorDB):
                     payloads=payloads[i : i + self.BATCH_SIZE],
                     vectors=embeddings[i : i + self.BATCH_SIZE],
                 ),
+                **kwargs,
             )
 
     def query(
@@ -167,6 +169,7 @@ class QdrantDB(BaseVectorDB):
         where: Dict[str, any],
         skip_embedding: bool,
         citations: bool = False,
+        **kwargs: Optional[Dict[str, Any]],
     ) -> Union[List[Tuple[str, str, str]], List[str]]:
         """
         query contents from vector database based on vector similarity
@@ -208,6 +211,7 @@ class QdrantDB(BaseVectorDB):
             query_filter=models.Filter(must=qdrant_must_filters),
             query_vector=query_vector,
             limit=n_results,
+            **kwargs,
         )
 
         contexts = []

+ 6 - 2
embedchain/vectordb/weaviate.py

@@ -1,6 +1,6 @@
 import copy
 import os
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 try:
     import weaviate
@@ -158,6 +158,7 @@ class WeaviateDB(BaseVectorDB):
         metadatas: List[object],
         ids: List[str],
         skip_embedding: bool,
+        **kwargs: Optional[Dict[str, any]],
     ):
         """add data in vector database
         :param embeddings: list of embeddings for the corresponding documents to be added
@@ -192,7 +193,9 @@ class WeaviateDB(BaseVectorDB):
                     class_name=self.index_name + "_metadata",
                     vector=embedding,
                 )
-                batch.add_reference(obj_uuid, self.index_name, "metadata", metadata_uuid, self.index_name + "_metadata")
+                batch.add_reference(
+                    obj_uuid, self.index_name, "metadata", metadata_uuid, self.index_name + "_metadata", **kwargs
+                )
 
     def query(
         self,
@@ -201,6 +204,7 @@ class WeaviateDB(BaseVectorDB):
         where: Dict[str, any],
         skip_embedding: bool,
         citations: bool = False,
+        **kwargs: Optional[Dict[str, Any]],
     ) -> Union[List[Tuple[str, str, str]], List[str]]:
         """
         query contents from vector database based on vector similarity

+ 6 - 2
embedchain/vectordb/zilliz.py

@@ -1,5 +1,5 @@
 import logging
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 from embedchain.config import ZillizDBConfig
 from embedchain.helpers.json_serializable import register_deserializable
@@ -113,6 +113,7 @@ class ZillizVectorDB(BaseVectorDB):
         metadatas: List[object],
         ids: List[str],
         skip_embedding: bool,
+        **kwargs: Optional[Dict[str, any]],
     ):
         """Add to database"""
         if not skip_embedding:
@@ -120,7 +121,7 @@ class ZillizVectorDB(BaseVectorDB):
 
         for id, doc, metadata, embedding in zip(ids, documents, metadatas, embeddings):
             data = {**metadata, "id": id, "text": doc, "embeddings": embedding}
-            self.client.insert(collection_name=self.config.collection_name, data=data)
+            self.client.insert(collection_name=self.config.collection_name, data=data, **kwargs)
 
         self.collection.load()
         self.collection.flush()
@@ -133,6 +134,7 @@ class ZillizVectorDB(BaseVectorDB):
         where: Dict[str, any],
         skip_embedding: bool,
         citations: bool = False,
+        **kwargs: Optional[Dict[str, Any]],
     ) -> Union[List[Tuple[str, str, str]], List[str]]:
         """
         Query contents from vector data base based on vector similarity
@@ -165,6 +167,7 @@ class ZillizVectorDB(BaseVectorDB):
                 data=query_vector,
                 limit=n_results,
                 output_fields=output_fields,
+                **kwargs,
             )
 
         else:
@@ -176,6 +179,7 @@ class ZillizVectorDB(BaseVectorDB):
                 data=[query_vector],
                 limit=n_results,
                 output_fields=output_fields,
+                **kwargs,
             )
 
         contexts = []

+ 3 - 3
tests/vectordb/test_pinecone.py

@@ -57,11 +57,11 @@ class TestPinecone:
         db.add(vectors, documents, metadatas, ids, True)
 
         expected_pinecone_upsert_args = [
-            {"id": "doc1", "metadata": {"text": "This is a document."}, "values": [0, 0, 0]},
-            {"id": "doc2", "metadata": {"text": "This is another document."}, "values": [1, 1, 1]},
+            {"id": "doc1", "values": [0, 0, 0], "metadata": {"text": "This is a document."}},
+            {"id": "doc2", "values": [1, 1, 1], "metadata": {"text": "This is another document."}},
         ]
         # Assert that the Pinecone client was called to upsert the documents
-        pinecone_client_mock.upsert.assert_called_once_with(expected_pinecone_upsert_args)
+        pinecone_client_mock.upsert.assert_called_once_with(tuple(expected_pinecone_upsert_args))
 
     @patch("embedchain.vectordb.pinecone.pinecone")
     def test_query_documents(self, pinecone_mock):