#1128 | Remove deprecated type hints from typing module (#1131)

Sandra Serrano, 1 year ago
commit 0de9491c61
41 changed files with 272 additions and 267 deletions
  1. +3 -3    embedchain/app.py
  2. +3 -3    embedchain/bots/poe.py
  3. +2 -2    embedchain/cache.py
  4. +2 -1    embedchain/config/add_config.py
  5. +3 -3    embedchain/config/base_config.py
  6. +4 -4    embedchain/config/cache_config.py
  7. +6 -6    embedchain/config/llm/base.py
  8. +6 -6    embedchain/config/vectordb/elasticsearch.py
  9. +4 -4    embedchain/config/vectordb/opensearch.py
  10. +2 -2   embedchain/config/vectordb/pinecone.py
  11. +6 -6   embedchain/config/vectordb/qdrant.py
  12. +2 -2   embedchain/config/vectordb/weaviate.py
  13. +25 -25 embedchain/embedchain.py
  14. +2 -1   embedchain/embedder/base.py
  15. +2 -2   embedchain/helpers/callbacks.py
  16. +3 -3   embedchain/helpers/json_serializable.py
  17. +11 -10 embedchain/llm/base.py
  18. +2 -1   embedchain/llm/google.py
  19. +2 -1   embedchain/llm/gpt4all.py
  20. +2 -1   embedchain/llm/ollama.py
  21. +2 -2   embedchain/llm/openai.py
  22. +2 -2   embedchain/loaders/directory_loader.py
  23. +2 -2   embedchain/loaders/discourse.py
  24. +1 -2   embedchain/loaders/dropbox.py
  25. +2 -2   embedchain/loaders/github.py
  26. +5 -5   embedchain/loaders/gmail.py
  27. +4 -4   embedchain/loaders/json.py
  28. +3 -3   embedchain/loaders/mysql.py
  29. +3 -3   embedchain/loaders/notion.py
  30. +3 -3   embedchain/loaders/postgres.py
  31. +3 -3   embedchain/loaders/slack.py
  32. +3 -3   embedchain/memory/base.py
  33. +3 -3   embedchain/memory/message.py
  34. +5 -5   embedchain/memory/utils.py
  35. +22 -22 embedchain/vectordb/chroma.py
  36. +20 -20 embedchain/vectordb/elasticsearch.py
  37. +22 -22 embedchain/vectordb/opensearch.py
  38. +19 -19 embedchain/vectordb/pinecone.py
  39. +20 -20 embedchain/vectordb/qdrant.py
  40. +20 -20 embedchain/vectordb/weaviate.py
  41. +16 -16 embedchain/vectordb/zilliz.py
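
Every change below follows the same PEP 585 pattern: from Python 3.9 onward the builtin containers (dict, list, tuple, set) are subscriptable, so the typing-module aliases Dict, List, Tuple, and Set are deprecated, and abstract types such as Callable, Generator, and Iterable come from collections.abc instead of typing. A minimal sketch of the rewrite, using hypothetical names rather than code from this commit (assumes Python 3.9+):

    # Before: deprecated aliases from the typing module.
    #   from typing import Callable, Dict, List, Optional, Tuple
    #   def top_score(scores: Dict[str, float], tags: List[str]) -> Tuple[str, float]: ...

    # After: builtin generics plus collections.abc, the style this commit applies.
    from collections.abc import Callable
    from typing import Any, Optional

    def top_score(scores: dict[str, float], tags: list[str]) -> tuple[str, float]:
        # Builtin containers accept subscripts at runtime on Python 3.9+.
        best = max(scores, key=lambda name: scores[name])
        return best, scores[best]

    def run(callback: Optional[Callable[[str], None]] = None, **params: Any) -> None:
        # typing.Callable is deprecated; collections.abc.Callable replaces it.
        if callback is not None:
            callback("done")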

+ 3 - 3
embedchain/app.py

@@ -4,7 +4,7 @@ import logging
 import os
 import sqlite3
 import uuid
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 import requests
 import yaml
@@ -364,7 +364,7 @@ class App(EmbedChain):
     def from_config(
         cls,
         config_path: Optional[str] = None,
-        config: Optional[Dict[str, Any]] = None,
+        config: Optional[dict[str, Any]] = None,
         auto_deploy: bool = False,
         yaml_path: Optional[str] = None,
     ):
@@ -374,7 +374,7 @@ class App(EmbedChain):
         :param config_path: Path to the YAML or JSON configuration file.
         :type config_path: Optional[str]
         :param config: A dictionary containing the configuration.
-        :type config: Optional[Dict[str, Any]]
+        :type config: Optional[dict[str, Any]]
         :param auto_deploy: Whether to deploy the pipeline automatically, defaults to False
         :type auto_deploy: bool, optional
         :param yaml_path: (Deprecated) Path to the YAML configuration file. Use config_path instead.

+ 3 - 3
embedchain/bots/poe.py

@@ -1,7 +1,7 @@
 import argparse
 import logging
 import os
-from typing import List, Optional
+from typing import Optional
 
 from embedchain.helpers.json_serializable import register_deserializable
 
@@ -53,7 +53,7 @@ class PoeBot(BaseBot, PoeBot):
         answer = self.handle_message(last_message, history)
         yield self.text_event(answer)
 
-    def handle_message(self, message, history: Optional[List[str]] = None):
+    def handle_message(self, message, history: Optional[list[str]] = None):
         if message.startswith("/add "):
             response = self.add_data(message)
         else:
@@ -70,7 +70,7 @@ class PoeBot(BaseBot, PoeBot):
     #         response = "Some error occurred while adding data."
     #     return response
 
-    def ask_bot(self, message, history: List[str]):
+    def ask_bot(self, message, history: list[str]):
         try:
             self.app.llm.set_history(history=history)
             response = self.query(message)

+ 2 - 2
embedchain/cache.py

@@ -1,6 +1,6 @@
 import logging
 import os  # noqa: F401
-from typing import Any, Dict
+from typing import Any
 
 from gptcache import cache  # noqa: F401
 from gptcache.adapter.adapter import adapt  # noqa: F401
@@ -15,7 +15,7 @@ from gptcache.similarity_evaluation.exact_match import \
     ExactMatchEvaluation  # noqa: F401
 
 
-def gptcache_pre_function(data: Dict[str, Any], **params: Dict[str, Any]):
+def gptcache_pre_function(data: dict[str, Any], **params: dict[str, Any]):
     return data["input_query"]
 
 

+ 2 - 1
embedchain/config/add_config.py

@@ -1,7 +1,8 @@
 import builtins
 import logging
+from collections.abc import Callable
 from importlib import import_module
-from typing import Callable, Optional
+from typing import Optional
 
 from embedchain.config.base_config import BaseConfig
 from embedchain.helpers.json_serializable import register_deserializable
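
Note the split import above: Callable moves to collections.abc (Generator and Iterable get the same treatment in the llm modules further down), while Optional stays in typing, which has no collections.abc counterpart. A quick check, assuming Python 3.9+, that the collections.abc types parameterize at runtime:

    from collections.abc import Callable, Generator, Iterable

    # All three accept type parameters at runtime on Python 3.9+ (PEP 585).
    Handler = Callable[[str], None]
    IntStream = Generator[int, None, None]
    Names = Iterable[str]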

+ 3 - 3
embedchain/config/base_config.py

@@ -1,4 +1,4 @@
-from typing import Any, Dict
+from typing import Any
 
 from embedchain.helpers.json_serializable import JSONSerializable
 
@@ -12,10 +12,10 @@ class BaseConfig(JSONSerializable):
         """Initializes a configuration class for a class."""
         pass
 
-    def as_dict(self) -> Dict[str, Any]:
+    def as_dict(self) -> dict[str, Any]:
         """Return config object as a dict
 
         :return: config object as dict
-        :rtype: Dict[str, Any]
+        :rtype: dict[str, Any]
         """
         return vars(self)

+ 4 - 4
embedchain/config/cache_config.py

@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 from embedchain.config.base_config import BaseConfig
 from embedchain.helpers.json_serializable import register_deserializable
@@ -30,7 +30,7 @@ class CacheSimilarityEvalConfig(BaseConfig):
         self.positive = positive
 
     @staticmethod
-    def from_config(config: Optional[Dict[str, Any]]):
+    def from_config(config: Optional[dict[str, Any]]):
         if config is None:
             return CacheSimilarityEvalConfig()
         else:
@@ -65,7 +65,7 @@ class CacheInitConfig(BaseConfig):
         self.auto_flush = auto_flush
 
     @staticmethod
-    def from_config(config: Optional[Dict[str, Any]]):
+    def from_config(config: Optional[dict[str, Any]]):
         if config is None:
             return CacheInitConfig()
         else:
@@ -86,7 +86,7 @@ class CacheConfig(BaseConfig):
         self.init_config = init_config
 
     @staticmethod
-    def from_config(config: Optional[Dict[str, Any]]):
+    def from_config(config: Optional[dict[str, Any]]):
         if config is None:
             return CacheConfig()
         else:

+ 6 - 6
embedchain/config/llm/base.py

@@ -1,7 +1,7 @@
 import logging
 import re
 from string import Template
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
 
 from embedchain.config.base_config import BaseConfig
 from embedchain.helpers.json_serializable import register_deserializable
@@ -68,12 +68,12 @@ class BaseLlmConfig(BaseConfig):
         stream: bool = False,
         deployment_name: Optional[str] = None,
         system_prompt: Optional[str] = None,
-        where: Dict[str, Any] = None,
+        where: dict[str, Any] = None,
         query_type: Optional[str] = None,
-        callbacks: Optional[List] = None,
+        callbacks: Optional[list] = None,
         api_key: Optional[str] = None,
         endpoint: Optional[str] = None,
-        model_kwargs: Optional[Dict[str, Any]] = None,
+        model_kwargs: Optional[dict[str, Any]] = None,
     ):
         """
         Initializes a configuration class instance for the LLM.
@@ -106,7 +106,7 @@ class BaseLlmConfig(BaseConfig):
         :param system_prompt: System prompt string, defaults to None
         :type system_prompt: Optional[str], optional
         :param where: A dictionary of key-value pairs to filter the database results., defaults to None
-        :type where: Dict[str, Any], optional
+        :type where: dict[str, Any], optional
         :param api_key: The api key of the custom endpoint, defaults to None
         :type api_key: Optional[str], optional
         :param endpoint: The api url of the custom endpoint, defaults to None
@@ -114,7 +114,7 @@ class BaseLlmConfig(BaseConfig):
         :param model_kwargs: A dictionary of key-value pairs to pass to the model, defaults to None
         :type model_kwargs: Optional[Dict[str, Any]], optional
         :param callbacks: Langchain callback functions to use, defaults to None
-        :type callbacks: Optional[List], optional
+        :type callbacks: Optional[list], optional
         :param query_type: The type of query to use, defaults to None
         :type query_type: Optional[str], optional
         :raises ValueError: If the template is not valid as template should

+ 6 - 6
embedchain/config/vectordb/elasticsearch.py

@@ -1,5 +1,5 @@
 import os
-from typing import Dict, List, Optional, Union
+from typing import Optional, Union
 
 from embedchain.config.vectordb.base import BaseVectorDbConfig
 from embedchain.helpers.json_serializable import register_deserializable
@@ -11,9 +11,9 @@ class ElasticsearchDBConfig(BaseVectorDbConfig):
         self,
         collection_name: Optional[str] = None,
         dir: Optional[str] = None,
-        es_url: Union[str, List[str]] = None,
+        es_url: Union[str, list[str]] = None,
         cloud_id: Optional[str] = None,
-        **ES_EXTRA_PARAMS: Dict[str, any],
+        **ES_EXTRA_PARAMS: dict[str, any],
     ):
         """
         Initializes a configuration class instance for an Elasticsearch client.
@@ -23,13 +23,13 @@ class ElasticsearchDBConfig(BaseVectorDbConfig):
         :param dir: Path to the database directory, where the database is stored, defaults to None
         :type dir: Optional[str], optional
         :param es_url: elasticsearch url or list of nodes url to be used for connection, defaults to None
-        :type es_url: Union[str, List[str]], optional
+        :type es_url: Union[str, list[str]], optional
         :param ES_EXTRA_PARAMS: extra params dict that can be passed to elasticsearch.
-        :type ES_EXTRA_PARAMS: Dict[str, Any], optional
+        :type ES_EXTRA_PARAMS: dict[str, Any], optional
         """
         if es_url and cloud_id:
             raise ValueError("Only one of `es_url` and `cloud_id` can be set.")
-        # self, es_url: Union[str, List[str]] = None, **ES_EXTRA_PARAMS: Dict[str, any]):
+        # self, es_url: Union[str, list[str]] = None, **ES_EXTRA_PARAMS: dict[str, any]):
         self.ES_URL = es_url or os.environ.get("ELASTICSEARCH_URL")
         self.CLOUD_ID = cloud_id or os.environ.get("ELASTICSEARCH_CLOUD_ID")
         if not self.ES_URL and not self.CLOUD_ID:

+ 4 - 4
embedchain/config/vectordb/opensearch.py

@@ -1,4 +1,4 @@
-from typing import Dict, Optional, Tuple
+from typing import Optional
 
 from embedchain.config.vectordb.base import BaseVectorDbConfig
 from embedchain.helpers.json_serializable import register_deserializable
@@ -9,11 +9,11 @@ class OpenSearchDBConfig(BaseVectorDbConfig):
     def __init__(
         self,
         opensearch_url: str,
-        http_auth: Tuple[str, str],
+        http_auth: tuple[str, str],
         vector_dimension: int = 1536,
         collection_name: Optional[str] = None,
         dir: Optional[str] = None,
-        **extra_params: Dict[str, any],
+        **extra_params: dict[str, any],
     ):
         """
         Initializes a configuration class instance for an OpenSearch client.
@@ -23,7 +23,7 @@ class OpenSearchDBConfig(BaseVectorDbConfig):
         :param opensearch_url: URL of the OpenSearch domain
         :type opensearch_url: str, Eg, "http://localhost:9200"
         :param http_auth: Tuple of username and password
-        :type http_auth: Tuple[str, str], Eg, ("username", "password")
+        :type http_auth: tuple[str, str], Eg, ("username", "password")
         :param vector_dimension: Dimension of  the vector, defaults to 1536 (openai embedding model)
         :type vector_dimension: int, optional
         :param dir: Path to the database directory, where the database is stored, defaults to None

+ 2 - 2
embedchain/config/vectordb/pinecone.py

@@ -1,4 +1,4 @@
-from typing import Dict, Optional
+from typing import Optional
 
 from embedchain.config.vectordb.base import BaseVectorDbConfig
 from embedchain.helpers.json_serializable import register_deserializable
@@ -12,7 +12,7 @@ class PineconeDBConfig(BaseVectorDbConfig):
         dir: Optional[str] = None,
         vector_dimension: int = 1536,
         metric: Optional[str] = "cosine",
-        **extra_params: Dict[str, any],
+        **extra_params: dict[str, any],
     ):
         self.metric = metric
         self.vector_dimension = vector_dimension

+ 6 - 6
embedchain/config/vectordb/qdrant.py

@@ -1,4 +1,4 @@
-from typing import Dict, Optional
+from typing import Optional
 
 from embedchain.config.vectordb.base import BaseVectorDbConfig
 from embedchain.helpers.json_serializable import register_deserializable
@@ -15,10 +15,10 @@ class QdrantDBConfig(BaseVectorDbConfig):
         self,
         collection_name: Optional[str] = None,
         dir: Optional[str] = None,
-        hnsw_config: Optional[Dict[str, any]] = None,
-        quantization_config: Optional[Dict[str, any]] = None,
+        hnsw_config: Optional[dict[str, any]] = None,
+        quantization_config: Optional[dict[str, any]] = None,
         on_disk: Optional[bool] = None,
-        **extra_params: Dict[str, any],
+        **extra_params: dict[str, any],
     ):
         """
         Initializes a configuration class instance for a qdrant client.
@@ -28,9 +28,9 @@ class QdrantDBConfig(BaseVectorDbConfig):
         :param dir: Path to the database directory, where the database is stored, defaults to None
         :type dir: Optional[str], optional
         :param hnsw_config: Params for HNSW index
-        :type hnsw_config: Optional[Dict[str, any]], defaults to None
+        :type hnsw_config: Optional[dict[str, any]], defaults to None
         :param quantization_config: Params for quantization, if None - quantization will be disabled
-        :type quantization_config: Optional[Dict[str, any]], defaults to None
+        :type quantization_config: Optional[dict[str, any]], defaults to None
         :param on_disk: If true - point`s payload will not be stored in memory.
                 It will be read from the disk every time it is requested.
                 This setting saves RAM by (slightly) increasing the response time.

+ 2 - 2
embedchain/config/vectordb/weaviate.py

@@ -1,4 +1,4 @@
-from typing import Dict, Optional
+from typing import Optional
 
 from embedchain.config.vectordb.base import BaseVectorDbConfig
 from embedchain.helpers.json_serializable import register_deserializable
@@ -10,7 +10,7 @@ class WeaviateDBConfig(BaseVectorDbConfig):
         self,
         collection_name: Optional[str] = None,
         dir: Optional[str] = None,
-        **extra_params: Dict[str, any],
+        **extra_params: dict[str, any],
     ):
         self.extra_params = extra_params
         super().__init__(collection_name=collection_name, dir=dir)

+ 25 - 25
embedchain/embedchain.py

@@ -2,7 +2,7 @@ import hashlib
 import json
 import logging
 import sqlite3
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Optional, Union
 
 from dotenv import load_dotenv
 from langchain.docstore.document import Document
@@ -136,12 +136,12 @@ class EmbedChain(JSONSerializable):
         self,
         source: Any,
         data_type: Optional[DataType] = None,
-        metadata: Optional[Dict[str, Any]] = None,
+        metadata: Optional[dict[str, Any]] = None,
         config: Optional[AddConfig] = None,
         dry_run=False,
         loader: Optional[BaseLoader] = None,
         chunker: Optional[BaseChunker] = None,
-        **kwargs: Optional[Dict[str, Any]],
+        **kwargs: Optional[dict[str, Any]],
     ):
         """
         Adds the data from the given URL to the vector db.
@@ -154,7 +154,7 @@ class EmbedChain(JSONSerializable):
         defaults to None
         :type data_type: Optional[DataType], optional
         :param metadata: Metadata associated with the data source., defaults to None
-        :type metadata: Optional[Dict[str, Any]], optional
+        :type metadata: Optional[dict[str, Any]], optional
         :param config: The `AddConfig` instance to use as configuration options., defaults to None
         :type config: Optional[AddConfig], optional
         :raises ValueError: Invalid data type
@@ -243,9 +243,9 @@ class EmbedChain(JSONSerializable):
         self,
         source: Any,
         data_type: Optional[DataType] = None,
-        metadata: Optional[Dict[str, Any]] = None,
+        metadata: Optional[dict[str, Any]] = None,
         config: Optional[AddConfig] = None,
-        **kwargs: Optional[Dict[str, Any]],
+        **kwargs: Optional[dict[str, Any]],
     ):
         """
         Adds the data from the given URL to the vector db.
@@ -261,7 +261,7 @@ class EmbedChain(JSONSerializable):
         defaults to None
         :type data_type: Optional[DataType], optional
         :param metadata: Metadata associated with the data source., defaults to None
-        :type metadata: Optional[Dict[str, Any]], optional
+        :type metadata: Optional[dict[str, Any]], optional
         :param config: The `AddConfig` instance to use as configuration options., defaults to None
         :type config: Optional[AddConfig], optional
         :raises ValueError: Invalid data type
@@ -342,11 +342,11 @@ class EmbedChain(JSONSerializable):
         loader: BaseLoader,
         chunker: BaseChunker,
         src: Any,
-        metadata: Optional[Dict[str, Any]] = None,
+        metadata: Optional[dict[str, Any]] = None,
         source_hash: Optional[str] = None,
         add_config: Optional[AddConfig] = None,
         dry_run=False,
-        **kwargs: Optional[Dict[str, Any]],
+        **kwargs: Optional[dict[str, Any]],
     ):
         """
         Loads the data from the given URL, chunks it, and adds it to database.
@@ -359,7 +359,7 @@ class EmbedChain(JSONSerializable):
         :param source_hash: Hexadecimal hash of the source.
         :param dry_run: Optional. A dry run returns chunks and doesn't update DB.
         :type dry_run: bool, defaults to False
-        :return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks
+        :return: (list) documents (embedded text), (list) metadata, (list) ids, (int) number of chunks
         """
         existing_doc_id = self._get_existing_doc_id(chunker=chunker, src=src)
         app_id = self.config.id if self.config is not None else None
@@ -464,8 +464,8 @@ class EmbedChain(JSONSerializable):
         config: Optional[BaseLlmConfig] = None,
         where=None,
         citations: bool = False,
-        **kwargs: Optional[Dict[str, Any]],
-    ) -> Union[List[Tuple[str, str, str]], List[str]]:
+        **kwargs: Optional[dict[str, Any]],
+    ) -> Union[list[tuple[str, str, str]], list[str]]:
         """
         Queries the vector database based on the given input query.
         Gets relevant doc based on the query
@@ -479,7 +479,7 @@ class EmbedChain(JSONSerializable):
         :param citations: A boolean to indicate if db should fetch citation source
         :type citations: bool
         :return: List of contents of the document that matched your query
-        :rtype: List[str]
+        :rtype: list[str]
         """
         query_config = config or self.llm.config
         if where is not None:
@@ -507,10 +507,10 @@ class EmbedChain(JSONSerializable):
         input_query: str,
         config: BaseLlmConfig = None,
         dry_run=False,
-        where: Optional[Dict] = None,
+        where: Optional[dict] = None,
         citations: bool = False,
-        **kwargs: Dict[str, Any],
-    ) -> Union[Tuple[str, List[Tuple[str, Dict]]], str]:
+        **kwargs: dict[str, Any],
+    ) -> Union[tuple[str, list[tuple[str, dict]]], str]:
         """
         Queries the vector database based on the given input query.
         Gets relevant doc based on the query and then passes it to an
@@ -525,13 +525,13 @@ class EmbedChain(JSONSerializable):
         the LLM. The purpose is to test the prompt, not the response., defaults to False
         :type dry_run: bool, optional
         :param where: A dictionary of key-value pairs to filter the database results., defaults to None
-        :type where: Optional[Dict[str, str]], optional
+        :type where: Optional[dict[str, str]], optional
         :param kwargs: To read more params for the query function. Ex. we use citations boolean
         param to return context along with the answer
-        :type kwargs: Dict[str, Any]
+        :type kwargs: dict[str, Any]
         :return: The answer to the query, with citations if the citation flag is True
         or the dry run result
-        :rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]]
+        :rtype: str, if citations is False, otherwise tuple[str, list[tuple[str,str,str]]]
         """
         contexts = self._retrieve_from_database(
             input_query=input_query, config=config, where=where, citations=citations, **kwargs
@@ -572,10 +572,10 @@ class EmbedChain(JSONSerializable):
         config: Optional[BaseLlmConfig] = None,
         dry_run=False,
         session_id: str = "default",
-        where: Optional[Dict[str, str]] = None,
+        where: Optional[dict[str, str]] = None,
         citations: bool = False,
-        **kwargs: Dict[str, Any],
-    ) -> Union[Tuple[str, List[Tuple[str, Dict]]], str]:
+        **kwargs: dict[str, Any],
+    ) -> Union[tuple[str, list[tuple[str, dict]]], str]:
         """
         Queries the vector database on the given input query.
         Gets relevant doc based on the query and then passes it to an
@@ -594,13 +594,13 @@ class EmbedChain(JSONSerializable):
         :param session_id: The session id to use for chat history, defaults to 'default'.
         :type session_id: Optional[str], optional
         :param where: A dictionary of key-value pairs to filter the database results., defaults to None
-        :type where: Optional[Dict[str, str]], optional
+        :type where: Optional[dict[str, str]], optional
         :param kwargs: To read more params for the query function. Ex. we use citations boolean
         param to return context along with the answer
-        :type kwargs: Dict[str, Any]
+        :type kwargs: dict[str, Any]
         :return: The answer to the query, with citations if the citation flag is True
         or the dry run result
-        :rtype: str, if citations is False, otherwise Tuple[str,List[Tuple[str,str,str]]]
+        :rtype: str, if citations is False, otherwise tuple[str, list[tuple[str,str,str]]]
         """
         contexts = self._retrieve_from_database(
             input_query=input_query, config=config, where=where, citations=citations, **kwargs

+ 2 - 1
embedchain/embedder/base.py

@@ -1,4 +1,5 @@
-from typing import Any, Callable, Optional
+from collections.abc import Callable
+from typing import Any, Optional
 
 from embedchain.config.embedder.base import BaseEmbedderConfig
 

+ 2 - 2
embedchain/helpers/callbacks.py

@@ -1,5 +1,5 @@
 import queue
-from typing import Any, Dict, List, Union
+from typing import Any, Union
 
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from langchain.schema import LLMResult
@@ -29,7 +29,7 @@ class StreamingStdOutCallbackHandlerYield(StreamingStdOutCallbackHandler):
         super().__init__()
         self.q = q
 
-    def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
+    def on_llm_start(self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any) -> None:
         """Run when LLM starts running."""
         with self.q.mutex:
             self.q.queue.clear()

+ 3 - 3
embedchain/helpers/json_serializable.py

@@ -1,7 +1,7 @@
 import json
 import logging
 from string import Template
-from typing import Any, Dict, Type, TypeVar, Union
+from typing import Any, Type, TypeVar, Union
 
 T = TypeVar("T", bound="JSONSerializable")
 
@@ -84,7 +84,7 @@ class JSONSerializable:
             return cls()
 
     @staticmethod
-    def _auto_encoder(obj: Any) -> Union[Dict[str, Any], None]:
+    def _auto_encoder(obj: Any) -> Union[dict[str, Any], None]:
         """
         Automatically encode an object for JSON serialization.
 
@@ -126,7 +126,7 @@ class JSONSerializable:
         raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
 
     @classmethod
-    def _auto_decoder(cls, dct: Dict[str, Any]) -> Any:
+    def _auto_decoder(cls, dct: dict[str, Any]) -> Any:
         """
         Automatically decode a dictionary to an object during JSON deserialization.
 

+ 11 - 10
embedchain/llm/base.py

@@ -1,5 +1,6 @@
 import logging
-from typing import Any, Dict, Generator, List, Optional
+from collections.abc import Generator
+from typing import Any, Optional
 
 from langchain.schema import BaseMessage as LCBaseMessage
 
@@ -55,7 +56,7 @@ class BaseLlm(JSONSerializable):
         app_id: str,
         question: str,
         answer: str,
-        metadata: Optional[Dict[str, Any]] = None,
+        metadata: Optional[dict[str, Any]] = None,
         session_id: str = "default",
     ):
         chat_message = ChatMessage()
@@ -64,7 +65,7 @@ class BaseLlm(JSONSerializable):
         self.memory.add(app_id=app_id, chat_message=chat_message, session_id=session_id)
         self.update_history(app_id=app_id, session_id=session_id)
 
-    def generate_prompt(self, input_query: str, contexts: List[str], **kwargs: Dict[str, Any]) -> str:
+    def generate_prompt(self, input_query: str, contexts: list[str], **kwargs: dict[str, Any]) -> str:
         """
         Generates a prompt based on the given query and context, ready to be
         passed to an LLM
@@ -72,7 +73,7 @@ class BaseLlm(JSONSerializable):
         :param input_query: The query to use.
         :type input_query: str
         :param contexts: List of similar documents to the query used as context.
-        :type contexts: List[str]
+        :type contexts: list[str]
         :return: The prompt
         :rtype: str
         """
@@ -170,7 +171,7 @@ class BaseLlm(JSONSerializable):
             yield chunk
         logging.info(f"Answer: {streamed_answer}")
 
-    def query(self, input_query: str, contexts: List[str], config: BaseLlmConfig = None, dry_run=False):
+    def query(self, input_query: str, contexts: list[str], config: BaseLlmConfig = None, dry_run=False):
         """
         Queries the vector database based on the given input query.
         Gets relevant doc based on the query and then passes it to an
@@ -179,7 +180,7 @@ class BaseLlm(JSONSerializable):
         :param input_query: The query to use.
         :type input_query: str
         :param contexts: Embeddings retrieved from the database to be used as context.
-        :type contexts: List[str]
+        :type contexts: list[str]
         :param config: The `BaseLlmConfig` instance to use as configuration options. This is used for one method call.
         To persistently use a config, declare it during app init., defaults to None
         :type config: Optional[BaseLlmConfig], optional
@@ -223,7 +224,7 @@ class BaseLlm(JSONSerializable):
                 self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config)
 
     def chat(
-        self, input_query: str, contexts: List[str], config: BaseLlmConfig = None, dry_run=False, session_id: str = None
+        self, input_query: str, contexts: list[str], config: BaseLlmConfig = None, dry_run=False, session_id: str = None
     ):
         """
         Queries the vector database on the given input query.
@@ -235,7 +236,7 @@ class BaseLlm(JSONSerializable):
         :param input_query: The query to use.
         :type input_query: str
         :param contexts: Embeddings retrieved from the database to be used as context.
-        :type contexts: List[str]
+        :type contexts: list[str]
         :param config: The `BaseLlmConfig` instance to use as configuration options. This is used for one method call.
         To persistently use a config, declare it during app init., defaults to None
         :type config: Optional[BaseLlmConfig], optional
@@ -281,7 +282,7 @@ class BaseLlm(JSONSerializable):
                 self.config: BaseLlmConfig = BaseLlmConfig.deserialize(prev_config)
 
     @staticmethod
-    def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> List[LCBaseMessage]:
+    def _get_messages(prompt: str, system_prompt: Optional[str] = None) -> list[LCBaseMessage]:
         """
         Construct a list of langchain messages
 
@@ -290,7 +291,7 @@ class BaseLlm(JSONSerializable):
         :param system_prompt: System prompt, defaults to None
         :type system_prompt: Optional[str], optional
         :return: List of messages
-        :rtype: List[BaseMessage]
+        :rtype: list[BaseMessage]
         """
         from langchain.schema import HumanMessage, SystemMessage
 

+ 2 - 1
embedchain/llm/google.py

@@ -1,7 +1,8 @@
 import importlib
 import logging
 import os
-from typing import Any, Generator, Optional, Union
+from collections.abc import Generator
+from typing import Any, Optional, Union
 
 import google.generativeai as genai
 

+ 2 - 1
embedchain/llm/gpt4all.py

@@ -1,6 +1,7 @@
 import os
+from collections.abc import Iterable
 from pathlib import Path
-from typing import Iterable, Optional, Union
+from typing import Optional, Union
 
 from langchain.callbacks.stdout import StdOutCallbackHandler
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

+ 2 - 1
embedchain/llm/ollama.py

@@ -1,4 +1,5 @@
-from typing import Iterable, Optional, Union
+from collections.abc import Iterable
+from typing import Optional, Union
 
 from langchain.callbacks.manager import CallbackManager
 from langchain.callbacks.stdout import StdOutCallbackHandler

+ 2 - 2
embedchain/llm/openai.py

@@ -1,6 +1,6 @@
 import json
 import os
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 from langchain.chat_models import ChatOpenAI
 from langchain.schema import AIMessage, HumanMessage, SystemMessage
@@ -12,7 +12,7 @@ from embedchain.llm.base import BaseLlm
 
 @register_deserializable
 class OpenAILlm(BaseLlm):
-    def __init__(self, config: Optional[BaseLlmConfig] = None, functions: Optional[Dict[str, Any]] = None):
+    def __init__(self, config: Optional[BaseLlmConfig] = None, functions: Optional[dict[str, Any]] = None):
         self.functions = functions
         super().__init__(config=config)
 

+ 2 - 2
embedchain/loaders/directory_loader.py

@@ -1,7 +1,7 @@
 import hashlib
 import logging
 from pathlib import Path
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 from embedchain.config import AddConfig
 from embedchain.data_formatter.data_formatter import DataFormatter
@@ -15,7 +15,7 @@ from embedchain.utils.misc import detect_datatype
 class DirectoryLoader(BaseLoader):
     """Load data from a directory."""
 
-    def __init__(self, config: Optional[Dict[str, Any]] = None):
+    def __init__(self, config: Optional[dict[str, Any]] = None):
         super().__init__()
         config = config or {}
         self.recursive = config.get("recursive", True)

+ 2 - 2
embedchain/loaders/discourse.py

@@ -1,7 +1,7 @@
 import hashlib
 import logging
 import time
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 import requests
 
@@ -10,7 +10,7 @@ from embedchain.utils.misc import clean_string
 
 
 class DiscourseLoader(BaseLoader):
-    def __init__(self, config: Optional[Dict[str, Any]] = None):
+    def __init__(self, config: Optional[dict[str, Any]] = None):
         super().__init__()
         if not config:
             raise ValueError(

+ 1 - 2
embedchain/loaders/dropbox.py

@@ -1,6 +1,5 @@
 import hashlib
 import os
-from typing import List
 
 from dropbox.files import FileMetadata
 
@@ -29,7 +28,7 @@ class DropboxLoader(BaseLoader):
         except exceptions.AuthError as ex:
             raise ValueError("Invalid Dropbox access token. Please verify your token and try again.") from ex
 
-    def _download_folder(self, path: str, local_root: str) -> List[FileMetadata]:
+    def _download_folder(self, path: str, local_root: str) -> list[FileMetadata]:
         """Download a folder from Dropbox and save it preserving the directory structure."""
         entries = self.dbx.files_list_folder(path).entries
         for entry in entries:

+ 2 - 2
embedchain/loaders/github.py

@@ -4,7 +4,7 @@ import logging
 import os
 import re
 import shlex
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 from tqdm import tqdm
 
@@ -20,7 +20,7 @@ VALID_SEARCH_TYPES = set(["code", "repo", "pr", "issue", "discussion"])
 class GithubLoader(BaseLoader):
     """Load data from GitHub search query."""
 
-    def __init__(self, config: Optional[Dict[str, Any]] = None):
+    def __init__(self, config: Optional[dict[str, Any]] = None):
         super().__init__()
         if not config:
             raise ValueError(

+ 5 - 5
embedchain/loaders/gmail.py

@@ -5,7 +5,7 @@ import os
 from email import message_from_bytes
 from email.utils import parsedate_to_datetime
 from textwrap import dedent
-from typing import Dict, List, Optional
+from typing import Optional
 
 from bs4 import BeautifulSoup
 
@@ -57,7 +57,7 @@ class GmailReader:
                 token.write(creds.to_json())
         return creds
 
-    def load_emails(self) -> List[Dict]:
+    def load_emails(self) -> list[dict]:
         response = self.service.users().messages().list(userId="me", q=self.query).execute()
         messages = response.get("messages", [])
 
@@ -67,7 +67,7 @@ class GmailReader:
         raw_message = self.service.users().messages().get(userId="me", id=message_id, format="raw").execute()
         return base64.urlsafe_b64decode(raw_message["raw"])
 
-    def _parse_email(self, raw_email) -> Dict:
+    def _parse_email(self, raw_email) -> dict:
         mime_msg = message_from_bytes(raw_email)
         return {
             "subject": self._get_header(mime_msg, "Subject"),
@@ -124,7 +124,7 @@ class GmailLoader(BaseLoader):
         return {"doc_id": self._generate_doc_id(query, data), "data": data}
 
     @staticmethod
-    def _process_email(email: Dict) -> str:
+    def _process_email(email: dict) -> str:
         content = BeautifulSoup(email["body"], "html.parser").get_text()
         content = clean_string(content)
         return dedent(
@@ -137,6 +137,6 @@ class GmailLoader(BaseLoader):
         )
 
     @staticmethod
-    def _generate_doc_id(query: str, data: List[Dict]) -> str:
+    def _generate_doc_id(query: str, data: list[dict]) -> str:
         content_strings = [email["content"] for email in data]
         return hashlib.sha256((query + ", ".join(content_strings)).encode()).hexdigest()

+ 4 - 4
embedchain/loaders/json.py

@@ -2,7 +2,7 @@ import hashlib
 import json
 import os
 import re
-from typing import Dict, List, Union
+from typing import Union
 
 import requests
 
@@ -16,14 +16,14 @@ class JSONReader:
         pass
 
     @staticmethod
-    def load_data(json_data: Union[Dict, str]) -> List[str]:
+    def load_data(json_data: Union[dict, str]) -> list[str]:
         """Load data from a JSON structure.
 
         Args:
-            json_data (Union[Dict, str]): The JSON data to load.
+            json_data (Union[dict, str]): The JSON data to load.
 
         Returns:
-            List[str]: A list of strings representing the leaf nodes of the JSON.
+            list[str]: A list of strings representing the leaf nodes of the JSON.
         """
         if isinstance(json_data, str):
             json_data = json.loads(json_data)

+ 3 - 3
embedchain/loaders/mysql.py

@@ -1,13 +1,13 @@
 import hashlib
 import logging
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.utils.misc import clean_string
 
 
 class MySQLLoader(BaseLoader):
-    def __init__(self, config: Optional[Dict[str, Any]]):
+    def __init__(self, config: Optional[dict[str, Any]]):
         super().__init__()
         if not config:
             raise ValueError(
@@ -20,7 +20,7 @@ class MySQLLoader(BaseLoader):
         self.cursor = None
         self._setup_loader(config=config)
 
-    def _setup_loader(self, config: Dict[str, Any]):
+    def _setup_loader(self, config: dict[str, Any]):
         try:
             import mysql.connector as sqlconnector
         except ImportError as e:

+ 3 - 3
embedchain/loaders/notion.py

@@ -1,7 +1,7 @@
 import hashlib
 import logging
 import os
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
 
 import requests
 
@@ -15,7 +15,7 @@ class NotionDocument:
     A simple Document class to hold the text and additional information of a page.
     """
 
-    def __init__(self, text: str, extra_info: Dict[str, Any]):
+    def __init__(self, text: str, extra_info: dict[str, Any]):
         self.text = text
         self.extra_info = extra_info
 
@@ -82,7 +82,7 @@ class NotionPageLoader:
         result_lines = "\n".join(result_lines_arr)
         return result_lines
 
-    def load_data(self, page_ids: List[str]) -> List[NotionDocument]:
+    def load_data(self, page_ids: list[str]) -> list[NotionDocument]:
         """Load data from the given list of page IDs."""
         docs = []
         for page_id in page_ids:

+ 3 - 3
embedchain/loaders/postgres.py

@@ -1,12 +1,12 @@
 import hashlib
 import logging
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 from embedchain.loaders.base_loader import BaseLoader
 
 
 class PostgresLoader(BaseLoader):
-    def __init__(self, config: Optional[Dict[str, Any]] = None):
+    def __init__(self, config: Optional[dict[str, Any]] = None):
         super().__init__()
         if not config:
             raise ValueError(f"Must provide the valid config. Received: {config}")
@@ -15,7 +15,7 @@ class PostgresLoader(BaseLoader):
         self.cursor = None
         self._setup_loader(config=config)
 
-    def _setup_loader(self, config: Dict[str, Any]):
+    def _setup_loader(self, config: dict[str, Any]):
         try:
             import psycopg
         except ImportError as e:

+ 3 - 3
embedchain/loaders/slack.py

@@ -2,7 +2,7 @@ import hashlib
 import logging
 import os
 import ssl
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 import certifi
 
@@ -13,7 +13,7 @@ SLACK_API_BASE_URL = "https://www.slack.com/api/"
 
 
 class SlackLoader(BaseLoader):
-    def __init__(self, config: Optional[Dict[str, Any]] = None):
+    def __init__(self, config: Optional[dict[str, Any]] = None):
         super().__init__()
 
         self.config = config if config else {}
@@ -24,7 +24,7 @@ class SlackLoader(BaseLoader):
         self.client = None
         self._setup_loader(self.config)
 
-    def _setup_loader(self, config: Dict[str, Any]):
+    def _setup_loader(self, config: dict[str, Any]):
         try:
             from slack_sdk import WebClient
         except ImportError as e:

+ 3 - 3
embedchain/memory/base.py

@@ -2,7 +2,7 @@ import json
 import logging
 import sqlite3
 import uuid
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
 
 from embedchain.constants import SQLITE_PATH
 from embedchain.memory.message import ChatMessage
@@ -67,7 +67,7 @@ class ChatHistory:
         self.cursor.execute(DELETE_CHAT_HISTORY_QUERY, (app_id, session_id))
         self.connection.commit()
 
-    def get(self, app_id, session_id, num_rounds=10, display_format=False) -> List[ChatMessage]:
+    def get(self, app_id, session_id, num_rounds=10, display_format=False) -> list[ChatMessage]:
         """
         Get the most recent num_rounds rounds of conversations
         between human and AI, for a given app_id.
@@ -114,7 +114,7 @@ class ChatHistory:
         return count
 
     @staticmethod
-    def _serialize_json(metadata: Dict[str, Any]):
+    def _serialize_json(metadata: dict[str, Any]):
         return json.dumps(metadata)
 
     @staticmethod

+ 3 - 3
embedchain/memory/message.py

@@ -1,5 +1,5 @@
 import logging
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 from embedchain.helpers.json_serializable import JSONSerializable
 
@@ -18,9 +18,9 @@ class BaseMessage(JSONSerializable):
     created_by: str
 
     # Any additional info.
-    metadata: Dict[str, Any]
+    metadata: dict[str, Any]
 
-    def __init__(self, content: str, created_by: str, metadata: Optional[Dict[str, Any]] = None) -> None:
+    def __init__(self, content: str, created_by: str, metadata: Optional[dict[str, Any]] = None) -> None:
         super().__init__()
         self.content = content
         self.created_by = created_by

+ 5 - 5
embedchain/memory/utils.py

@@ -1,16 +1,16 @@
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 
-def merge_metadata_dict(left: Optional[Dict[str, Any]], right: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
+def merge_metadata_dict(left: Optional[dict[str, Any]], right: Optional[dict[str, Any]]) -> Optional[dict[str, Any]]:
     """
     Merge the metadatas of two BaseMessage types.
 
     Args:
-        left (Dict[str, Any]): metadata of human message
-        right (Dict[str, Any]): metadata of AI message
+        left (dict[str, Any]): metadata of human message
+        right (dict[str, Any]): metadata of AI message
 
     Returns:
-        Dict[str, Any]: combined metadata dict with dedup
+        dict[str, Any]: combined metadata dict with dedup
         to be saved in db.
     """
     if not left and not right:

+ 22 - 22
embedchain/vectordb/chroma.py

@@ -1,5 +1,5 @@
 import logging
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Optional, Union
 
 from chromadb import Collection, QueryResult
 from langchain.docstore.document import Document
@@ -76,7 +76,7 @@ class ChromaDB(BaseVectorDB):
         return self.client
 
     @staticmethod
-    def _generate_where_clause(where: Dict[str, any]) -> Dict[str, any]:
+    def _generate_where_clause(where: dict[str, any]) -> dict[str, any]:
         # If only one filter is supplied, return it as is
         # (no need to wrap in $and based on chroma docs)
         if len(where.keys()) <= 1:
@@ -105,18 +105,18 @@ class ChromaDB(BaseVectorDB):
         )
         return self.collection
 
-    def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
+    def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
         """
         Get existing doc ids present in vector database
 
         :param ids: list of doc ids to check for existence
-        :type ids: List[str]
+        :type ids: list[str]
         :param where: Optional. to filter data
-        :type where: Dict[str, Any]
+        :type where: dict[str, Any]
         :param limit: Optional. maximum number of documents
         :type limit: Optional[int]
         :return: Existing documents.
-        :rtype: List[str]
+        :rtype: list[str]
         """
         args = {}
         if ids:
@@ -129,23 +129,23 @@ class ChromaDB(BaseVectorDB):
 
     def add(
         self,
-        embeddings: List[List[float]],
-        documents: List[str],
-        metadatas: List[object],
-        ids: List[str],
-        **kwargs: Optional[Dict[str, Any]],
+        embeddings: list[list[float]],
+        documents: list[str],
+        metadatas: list[object],
+        ids: list[str],
+        **kwargs: Optional[dict[str, Any]],
     ) -> Any:
         """
         Add vectors to chroma database
 
         :param embeddings: list of embeddings to add
-        :type embeddings: List[List[str]]
+        :type embeddings: list[list[str]]
         :param documents: Documents
-        :type documents: List[str]
+        :type documents: list[str]
         :param metadatas: Metadatas
-        :type metadatas: List[object]
+        :type metadatas: list[object]
         :param ids: ids
-        :type ids: List[str]
+        :type ids: list[str]
         """
         size = len(documents)
         if len(documents) != size or len(metadatas) != size or len(ids) != size:
@@ -182,27 +182,27 @@ class ChromaDB(BaseVectorDB):
 
     def query(
         self,
-        input_query: List[str],
+        input_query: list[str],
         n_results: int,
-        where: Dict[str, any],
+        where: dict[str, any],
         citations: bool = False,
-        **kwargs: Optional[Dict[str, Any]],
-    ) -> Union[List[Tuple[str, Dict]], List[str]]:
+        **kwargs: Optional[dict[str, Any]],
+    ) -> Union[list[tuple[str, dict]], list[str]]:
         """
         Query contents from vector database based on vector similarity
 
         :param input_query: list of query string
-        :type input_query: List[str]
+        :type input_query: list[str]
         :param n_results: no of similar documents to fetch from database
         :type n_results: int
         :param where: to filter data
-        :type where: Dict[str, Any]
+        :type where: dict[str, Any]
         :param citations: we use citations boolean param to return context along with the answer.
         :type citations: bool, default is False.
         :raises InvalidDimensionException: Dimensions do not match.
         :return: The content of the document that matched your query,
         along with url of the source and doc_id (if citations flag is true)
-        :rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
+        :rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
         """
         try:
             result = self.collection.query(

+ 20 - 20
embedchain/vectordb/elasticsearch.py

@@ -1,5 +1,5 @@
 import logging
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Optional, Union
 
 try:
     from elasticsearch import Elasticsearch
@@ -84,14 +84,14 @@ class ElasticsearchDB(BaseVectorDB):
     def _get_or_create_collection(self, name):
         """Note: nothing to return here. Discuss later"""
 
-    def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
+    def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
         """
         Get existing doc ids present in vector database
 
         :param ids: _list of doc ids to check for existence
-        :type ids: List[str]
+        :type ids: list[str]
         :param where: to filter data
-        :type where: Dict[str, any]
+        :type where: dict[str, any]
         :return: ids
         :rtype: Set[str]
         """
@@ -110,22 +110,22 @@ class ElasticsearchDB(BaseVectorDB):
 
     def add(
         self,
-        embeddings: List[List[float]],
-        documents: List[str],
-        metadatas: List[object],
-        ids: List[str],
-        **kwargs: Optional[Dict[str, any]],
+        embeddings: list[list[float]],
+        documents: list[str],
+        metadatas: list[object],
+        ids: list[str],
+        **kwargs: Optional[dict[str, any]],
     ) -> Any:
         """
         add data in vector database
         :param embeddings: list of embeddings to add
-        :type embeddings: List[List[str]]
+        :type embeddings: list[list[str]]
         :param documents: list of texts to add
-        :type documents: List[str]
+        :type documents: list[str]
         :param metadatas: list of metadata associated with docs
-        :type metadatas: List[object]
+        :type metadatas: list[object]
         :param ids: ids of docs
-        :type ids: List[str]
+        :type ids: list[str]
         """
 
         embeddings = self.embedder.embedding_fn(documents)
@@ -154,27 +154,27 @@ class ElasticsearchDB(BaseVectorDB):
 
     def query(
         self,
-        input_query: List[str],
+        input_query: list[str],
         n_results: int,
-        where: Dict[str, any],
+        where: dict[str, any],
         citations: bool = False,
-        **kwargs: Optional[Dict[str, Any]],
-    ) -> Union[List[Tuple[str, Dict]], List[str]]:
+        **kwargs: Optional[dict[str, Any]],
+    ) -> Union[list[tuple[str, dict]], list[str]]:
         """
         query contents from vector database based on vector similarity
 
         :param input_query: list of query string
-        :type input_query: List[str]
+        :type input_query: list[str]
         :param n_results: no of similar documents to fetch from database
         :type n_results: int
         :param where: Optional. to filter data
-        :type where: Dict[str, any]
+        :type where: dict[str, any]
         :return: The context of the document that matched your query, url of the source, doc_id
         :param citations: we use citations boolean param to return context along with the answer.
         :type citations: bool, default is False.
         :return: The content of the document that matched your query,
         along with url of the source and doc_id (if citations flag is true)
-        :rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
+        :rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
         """
         input_query_vector = self.embedder.embedding_fn(input_query)
         query_vector = input_query_vector[0]

+ 22 - 22
embedchain/vectordb/opensearch.py

@@ -1,6 +1,6 @@
 import logging
 import time
-from typing import Any, Dict, List, Optional, Set, Tuple, Union
+from typing import Any, Optional, Union
 
 from tqdm import tqdm
 
@@ -78,17 +78,17 @@ class OpenSearchDB(BaseVectorDB):
         """Note: nothing to return here. Discuss later"""
 
     def get(
-        self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None
-    ) -> Set[str]:
+        self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None
+    ) -> set[str]:
         """
         Get existing doc ids present in vector database
 
         :param ids: _list of doc ids to check for existence
-        :type ids: List[str]
+        :type ids: list[str]
         :param where: to filter data
-        :type where: Dict[str, any]
+        :type where: dict[str, any]
         :return: ids
-        :type: Set[str]
+        :type: set[str]
         """
         query = {}
         if ids:
@@ -116,19 +116,19 @@ class OpenSearchDB(BaseVectorDB):
 
     def add(
         self,
-        embeddings: List[List[str]],
-        documents: List[str],
-        metadatas: List[object],
-        ids: List[str],
-        **kwargs: Optional[Dict[str, any]],
+        embeddings: list[list[str]],
+        documents: list[str],
+        metadatas: list[object],
+        ids: list[str],
+        **kwargs: Optional[dict[str, any]],
     ):
         """Add data in vector database.
 
         Args:
-            embeddings (List[List[str]]): List of embeddings to add.
-            documents (List[str]): List of texts to add.
-            metadatas (List[object]): List of metadata associated with docs.
-            ids (List[str]): IDs of docs.
+            embeddings (list[list[str]]): list of embeddings to add.
+            documents (list[str]): list of texts to add.
+            metadatas (list[object]): list of metadata associated with docs.
+            ids (list[str]): IDs of docs.
         """
         for batch_start in tqdm(range(0, len(documents), self.BATCH_SIZE), desc="Inserting batches in opensearch"):
             batch_end = batch_start + self.BATCH_SIZE
@@ -156,26 +156,26 @@ class OpenSearchDB(BaseVectorDB):
 
     def query(
         self,
-        input_query: List[str],
+        input_query: list[str],
         n_results: int,
-        where: Dict[str, any],
+        where: dict[str, any],
         citations: bool = False,
-        **kwargs: Optional[Dict[str, Any]],
-    ) -> Union[List[Tuple[str, Dict]], List[str]]:
+        **kwargs: Optional[dict[str, Any]],
+    ) -> Union[list[tuple[str, dict]], list[str]]:
         """
         query contents from vector database based on vector similarity
 
         :param input_query: list of query string
-        :type input_query: List[str]
+        :type input_query: list[str]
         :param n_results: no of similar documents to fetch from database
         :type n_results: int
         :param where: Optional. to filter data
-        :type where: Dict[str, any]
+        :type where: dict[str, any]
         :param citations: we use citations boolean param to return context along with the answer.
         :type citations: bool, default is False.
         :return: The content of the document that matched your query,
         along with url of the source and doc_id (if citations flag is true)
-        :rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
+        :rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
         """
         embeddings = OpenAIEmbeddings()
         docsearch = OpenSearchVectorSearch(

+ 19 - 19
embedchain/vectordb/pinecone.py

@@ -1,5 +1,5 @@
 import os
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Optional, Union
 
 try:
     import pinecone
@@ -67,14 +67,14 @@ class PineconeDB(BaseVectorDB):
             )
         return pinecone.Index(self.index_name)
 
-    def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
+    def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
         """
         Get existing doc ids present in vector database
 
         :param ids: _list of doc ids to check for existence
-        :type ids: List[str]
+        :type ids: list[str]
         :param where: to filter data
-        :type where: Dict[str, any]
+        :type where: dict[str, any]
         :return: ids
         :rtype: Set[str]
         """
@@ -88,20 +88,20 @@ class PineconeDB(BaseVectorDB):
 
     def add(
         self,
-        embeddings: List[List[float]],
-        documents: List[str],
-        metadatas: List[object],
-        ids: List[str],
-        **kwargs: Optional[Dict[str, any]],
+        embeddings: list[list[float]],
+        documents: list[str],
+        metadatas: list[object],
+        ids: list[str],
+        **kwargs: Optional[dict[str, any]],
     ):
         """add data in vector database
 
         :param documents: list of texts to add
-        :type documents: List[str]
+        :type documents: list[str]
         :param metadatas: list of metadata associated with docs
-        :type metadatas: List[object]
+        :type metadatas: list[object]
         :param ids: ids of docs
-        :type ids: List[str]
+        :type ids: list[str]
         """
         docs = []
         print("Adding documents to Pinecone...")
@@ -120,25 +120,25 @@ class PineconeDB(BaseVectorDB):
 
     def query(
         self,
-        input_query: List[str],
+        input_query: list[str],
         n_results: int,
-        where: Dict[str, any],
+        where: dict[str, any],
         citations: bool = False,
-        **kwargs: Optional[Dict[str, any]],
-    ) -> Union[List[Tuple[str, Dict]], List[str]]:
+        **kwargs: Optional[dict[str, any]],
+    ) -> Union[list[tuple[str, dict]], list[str]]:
         """
         query contents from vector database based on vector similarity
         :param input_query: list of query string
-        :type input_query: List[str]
+        :type input_query: list[str]
         :param n_results: no of similar documents to fetch from database
         :type n_results: int
         :param where: Optional. to filter data
-        :type where: Dict[str, any]
+        :type where: dict[str, any]
         :param citations: we use citations boolean param to return context along with the answer.
         :type citations: bool, default is False.
         :return: The content of the document that matched your query,
         along with url of the source and doc_id (if citations flag is true)
-        :rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
+        :rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
         """
         query_vector = self.embedder.embedding_fn([input_query])[0]
         data = self.client.query(vector=query_vector, filter=where, top_k=n_results, include_metadata=True, **kwargs)

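A hypothetical call against the updated PineconeDB.query signature (the db instance and its contents are assumed for illustration). Two caveats visible in the diff itself: the annotation reads list[str], but the body wraps input_query in a list before embedding (embedding_fn([input_query])), so in practice a single query string is what gets passed; and the docstring's rtype (list[tuple[str, str, str]]) disagrees with the annotated return type (list[tuple[str, dict]]), so the sketch follows the annotation:

# citations=False (the default): returns plain contexts as list[str].
contexts = db.query(input_query="what is embedchain?", n_results=3, where={"app": "demo"})

# citations=True: returns list[tuple[str, dict]], pairing each document
# text with its metadata (source url, doc_id, ...).
for text, metadata in db.query(
    input_query="what is embedchain?", n_results=3, where={"app": "demo"}, citations=True
):
    print(metadata.get("url"), metadata.get("doc_id"))
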
+ 20 - 20
embedchain/vectordb/qdrant.py

@@ -1,7 +1,7 @@
 import copy
 import os
 import uuid
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Optional, Union
 
 try:
     from qdrant_client import QdrantClient
@@ -69,14 +69,14 @@ class QdrantDB(BaseVectorDB):
     def _get_or_create_collection(self):
         return f"{self.config.collection_name}-{self.embedder.vector_dimension}".lower().replace("_", "-")
 
-    def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
+    def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
         """
         Get existing doc ids present in vector database
 
         :param ids: list of doc ids to check for existence
-        :type ids: List[str]
+        :type ids: list[str]
         :param where: to filter data
-        :type where: Dict[str, any]
+        :type where: dict[str, any]
         :param limit: The number of entries to be fetched
         :type limit: Optional int, defaults to None
         :return: All the existing IDs
@@ -122,21 +122,21 @@ class QdrantDB(BaseVectorDB):
 
     def add(
         self,
-        embeddings: List[List[float]],
-        documents: List[str],
-        metadatas: List[object],
-        ids: List[str],
-        **kwargs: Optional[Dict[str, any]],
+        embeddings: list[list[float]],
+        documents: list[str],
+        metadatas: list[object],
+        ids: list[str],
+        **kwargs: Optional[dict[str, any]],
     ):
         """add data in vector database
         :param embeddings: list of embeddings for the corresponding documents to be added
-        :type documents: List[List[float]]
+        :type embeddings: list[list[float]]
         :param documents: list of texts to add
-        :type documents: List[str]
+        :type documents: list[str]
         :param metadatas: list of metadata associated with docs
-        :type metadatas: List[object]
+        :type metadatas: list[object]
         :param ids: ids of docs
-        :type ids: List[str]
+        :type ids: list[str]
         """
         embeddings = self.embedder.embedding_fn(documents)
 
@@ -159,25 +159,25 @@ class QdrantDB(BaseVectorDB):
 
     def query(
         self,
-        input_query: List[str],
+        input_query: list[str],
         n_results: int,
-        where: Dict[str, any],
+        where: dict[str, any],
         citations: bool = False,
-        **kwargs: Optional[Dict[str, Any]],
-    ) -> Union[List[Tuple[str, Dict]], List[str]]:
+        **kwargs: Optional[dict[str, Any]],
+    ) -> Union[list[tuple[str, dict]], list[str]]:
         """
         Query contents from the vector database based on vector similarity
         :param input_query: list of query strings
-        :type input_query: List[str]
+        :type input_query: list[str]
         :param n_results: number of similar documents to fetch from the database
         :type n_results: int
         :param where: Optional. Used to filter data
-        :type where: Dict[str, any]
+        :type where: dict[str, any]
         :param citations: whether to return the context (source url and doc_id) along with the answer.
         :type citations: bool, defaults to False.
         :return: The content of the document that matched your query,
         along with the url of the source and the doc_id (if the citations flag is true)
-        :rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
+        :rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
         """
         query_vector = self.embedder.embedding_fn([input_query])[0]
         keys = set(where.keys() if where is not None else set())

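One quirk these hunks carry over unchanged: most of the annotations subscript with the builtin function any rather than typing.Any (dict[str, any]). This is harmless at runtime, since types.GenericAlias accepts arbitrary subscripts, but static type checkers reject it. A small sketch of the difference:

from typing import Any

flags: dict[str, Any] = {"verbose": True}  # intended: values of any type
quirk: dict[str, any] = {"verbose": True}  # also runs, but `any` here is the
                                           # builtin function; mypy/pyright flag it
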
+ 20 - 20
embedchain/vectordb/weaviate.py

@@ -1,6 +1,6 @@
 import copy
 import os
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Optional, Union
 
 try:
     import weaviate
@@ -117,13 +117,13 @@ class WeaviateDB(BaseVectorDB):
 
             self.client.schema.create(class_obj)
 
-    def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
+    def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
         """
         Get existing doc ids present in vector database
         :param ids: list of doc ids to check for existence
-        :type ids: List[str]
+        :type ids: list[str]
         :param where: to filter data
-        :type where: Dict[str, any]
+        :type where: dict[str, any]
         :return: ids
         :rtype: Set[str]
         """
@@ -153,21 +153,21 @@ class WeaviateDB(BaseVectorDB):
 
     def add(
         self,
-        embeddings: List[List[float]],
-        documents: List[str],
-        metadatas: List[object],
-        ids: List[str],
-        **kwargs: Optional[Dict[str, any]],
+        embeddings: list[list[float]],
+        documents: list[str],
+        metadatas: list[object],
+        ids: list[str],
+        **kwargs: Optional[dict[str, any]],
     ):
         """add data in vector database
         :param embeddings: list of embeddings for the corresponding documents to be added
-        :type documents: List[List[float]]
+        :type embeddings: list[list[float]]
         :param documents: list of texts to add
-        :type documents: List[str]
+        :type documents: list[str]
         :param metadatas: list of metadata associated with docs
-        :type metadatas: List[object]
+        :type metadatas: list[object]
         :param ids: ids of docs
-        :type ids: List[str]
+        :type ids: list[str]
         """
         embeddings = self.embedder.embedding_fn(documents)
         self.client.batch.configure(batch_size=self.BATCH_SIZE, timeout_retries=3)  # Configure batch
@@ -192,25 +192,25 @@ class WeaviateDB(BaseVectorDB):
 
     def query(
         self,
-        input_query: List[str],
+        input_query: list[str],
         n_results: int,
-        where: Dict[str, any],
+        where: dict[str, any],
         citations: bool = False,
-        **kwargs: Optional[Dict[str, Any]],
-    ) -> Union[List[Tuple[str, Dict]], List[str]]:
+        **kwargs: Optional[dict[str, Any]],
+    ) -> Union[list[tuple[str, dict]], list[str]]:
         """
         Query contents from the vector database based on vector similarity
         :param input_query: list of query strings
-        :type input_query: List[str]
+        :type input_query: list[str]
         :param n_results: number of similar documents to fetch from the database
         :type n_results: int
         :param where: Optional. Used to filter data
-        :type where: Dict[str, any]
+        :type where: dict[str, any]
         :param citations: whether to return the context (source url and doc_id) along with the answer.
         :type citations: bool, defaults to False.
         :return: The content of the document that matched your query,
         along with the url of the source and the doc_id (if the citations flag is true)
-        :rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
+        :rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
         """
         query_vector = self.embedder.embedding_fn([input_query])[0]
         keys = set(where.keys() if where is not None else set())

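Worth noting for anyone running these classes on an older interpreter: builtin generics in annotations only evaluate on Python 3.9+. A minimal sketch of the usual escape hatch for 3.7/3.8, assuming a module written in the style of the add() methods above; with postponed evaluation the annotations are stored as strings and never evaluated at import time:

from __future__ import annotations  # PEP 563: postponed evaluation of annotations


def add(documents: list[str], metadatas: list[object], ids: list[str]) -> None:
    # Imports cleanly even on Python 3.7/3.8, because list[str] is kept as a
    # string and never evaluated at runtime.
    ...
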
+ 16 - 16
embedchain/vectordb/zilliz.py

@@ -1,5 +1,5 @@
 import logging
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Optional, Union
 
 from embedchain.config import ZillizDBConfig
 from embedchain.helpers.json_serializable import register_deserializable
@@ -88,14 +88,14 @@ class ZillizVectorDB(BaseVectorDB):
             self.collection.create_index("embeddings", index)
         return self.collection
 
-    def get(self, ids: Optional[List[str]] = None, where: Optional[Dict[str, any]] = None, limit: Optional[int] = None):
+    def get(self, ids: Optional[list[str]] = None, where: Optional[dict[str, any]] = None, limit: Optional[int] = None):
         """
         Get existing doc ids present in vector database
 
         :param ids: list of doc ids to check for existence
-        :type ids: List[str]
+        :type ids: list[str]
         :param where: Optional. Used to filter data
-        :type where: Dict[str, Any]
+        :type where: dict[str, Any]
         :param limit: Optional. maximum number of documents
         :type limit: Optional[int]
         :return: Existing documents.
@@ -115,11 +115,11 @@ class ZillizVectorDB(BaseVectorDB):
 
     def add(
         self,
-        embeddings: List[List[float]],
-        documents: List[str],
-        metadatas: List[object],
-        ids: List[str],
-        **kwargs: Optional[Dict[str, any]],
+        embeddings: list[list[float]],
+        documents: list[str],
+        metadatas: list[object],
+        ids: list[str],
+        **kwargs: Optional[dict[str, any]],
     ):
         """Add to database"""
         embeddings = self.embedder.embedding_fn(documents)
@@ -134,17 +134,17 @@ class ZillizVectorDB(BaseVectorDB):
 
     def query(
         self,
-        input_query: List[str],
+        input_query: list[str],
         n_results: int,
-        where: Dict[str, any],
+        where: dict[str, any],
         citations: bool = False,
-        **kwargs: Optional[Dict[str, Any]],
-    ) -> Union[List[Tuple[str, Dict]], List[str]]:
+        **kwargs: Optional[dict[str, Any]],
+    ) -> Union[list[tuple[str, dict]], list[str]]:
         """
         Query contents from the vector database based on vector similarity
 
         :param input_query: list of query strings
-        :type input_query: List[str]
+        :type input_query: list[str]
         :param n_results: number of similar documents to fetch from the database
         :type n_results: int
         :param where: to filter data
@@ -154,7 +154,7 @@ class ZillizVectorDB(BaseVectorDB):
         :type citations: bool, defaults to False.
         :return: The content of the document that matched your query,
         along with the url of the source and the doc_id (if the citations flag is true)
-        :rtype: List[str], if citations=False, otherwise List[Tuple[str, str, str]]
+        :rtype: list[str], if citations=False, otherwise list[tuple[str, str, str]]
         """
 
         if self.collection.is_empty:
@@ -200,7 +200,7 @@ class ZillizVectorDB(BaseVectorDB):
         """
         return self.collection.num_entities
 
-    def reset(self, collection_names: List[str] = None):
+    def reset(self, collection_names: list[str] = None):
         """
         Resets the database. Deletes all embeddings irreversibly.
         """