[Bug fix] Fix typos, static methods and other sanity improvements in the package (#1129)

Sandra Serrano 1 year ago
parent
commit
2496ed133e
41 changed files with 133 additions and 103 deletions
  1. embedchain/app.py (+9 -4)
  2. embedchain/chunkers/base_chunker.py (+6 -5)
  3. embedchain/client.py (+1 -1)
  4. embedchain/config/add_config.py (+3 -2)
  5. embedchain/config/cache_config.py (+6 -3)
  6. embedchain/config/llm/base.py (+6 -4)
  7. embedchain/config/vectordb/qdrant.py (+2 -2)
  8. embedchain/config/vectordb/zilliz.py (+1 -1)
  9. embedchain/data_formatter/data_formatter.py (+2 -1)
  10. embedchain/embedchain.py (+7 -9)
  11. embedchain/embedder/base.py (+3 -3)
  12. embedchain/embedder/google.py (+2 -2)
  13. embedchain/helpers/json_serializable.py (+1 -1)
  14. embedchain/llm/base.py (+8 -7)
  15. embedchain/llm/google.py (+1 -1)
  16. embedchain/llm/huggingface.py (+1 -1)
  17. embedchain/llm/ollama.py (+2 -1)
  18. embedchain/loaders/base_loader.py (+1 -1)
  19. embedchain/loaders/directory_loader.py (+1 -1)
  20. embedchain/loaders/docs_site_loader.py (+2 -1)
  21. embedchain/loaders/github.py (+19 -15)
  22. embedchain/loaders/image.py (+2 -1)
  23. embedchain/loaders/json.py (+2 -1)
  24. embedchain/loaders/mysql.py (+2 -1)
  25. embedchain/loaders/postgres.py (+2 -2)
  26. embedchain/loaders/slack.py (+2 -1)
  27. embedchain/loaders/unstructured_file.py (+1 -1)
  28. embedchain/loaders/web_page.py (+5 -4)
  29. embedchain/memory/base.py (+4 -2)
  30. embedchain/memory/message.py (+2 -2)
  31. embedchain/memory/utils.py (+1 -1)
  32. embedchain/store/assistants.py (+5 -3)
  33. embedchain/telemetry/posthog.py (+2 -1)
  34. embedchain/utils/misc.py (+2 -3)
  35. embedchain/vectordb/base.py (+1 -1)
  36. embedchain/vectordb/chroma.py (+4 -2)
  37. embedchain/vectordb/elasticsearch.py (+2 -2)
  38. embedchain/vectordb/opensearch.py (+1 -1)
  39. embedchain/vectordb/weaviate.py (+2 -1)
  40. embedchain/vectordb/zilliz.py (+5 -6)
  41. tests/chunkers/test_text.py (+2 -1)

+ 9 - 4
embedchain/app.py

@@ -9,9 +9,14 @@ from typing import Any, Dict, Optional
 import requests
 import yaml

-from embedchain.cache import (Config, ExactMatchEvaluation,
-                              SearchDistanceEvaluation, cache,
-                              gptcache_data_manager, gptcache_pre_function)
+from embedchain.cache import (
+    Config,
+    ExactMatchEvaluation,
+    SearchDistanceEvaluation,
+    cache,
+    gptcache_data_manager,
+    gptcache_pre_function,
+)
 from embedchain.client import Client
 from embedchain.config import AppConfig, CacheConfig, ChunkerConfig
 from embedchain.constants import SQLITE_PATH
@@ -27,7 +32,7 @@ from embedchain.utils.misc import validate_config
 from embedchain.vectordb.base import BaseVectorDB
 from embedchain.vectordb.chroma import ChromaDB

-# Setup the user directory if doesn't exist already
+# Set up the user directory if it doesn't exist already
 Client.setup_dir()

+ 6 - 5
embedchain/chunkers/base_chunker.py

@@ -17,7 +17,7 @@ class BaseChunker(JSONSerializable):
         """
         """
         Loads data and chunks it.
         Loads data and chunks it.
 
 
-        :param loader: The loader which's `load_data` method is used to create
+        :param loader: The loader whose `load_data` method is used to create
         the raw data.
         the raw data.
         :param src: The data to be handled by the loader. Can be a URL for
         :param src: The data to be handled by the loader. Can be a URL for
         remote sources or local content for local loaders.
         remote sources or local content for local loaders.
@@ -25,7 +25,7 @@ class BaseChunker(JSONSerializable):
         """
         """
         documents = []
         documents = []
         chunk_ids = []
         chunk_ids = []
-        idMap = {}
+        id_map = {}
         min_chunk_size = config.min_chunk_size if config is not None else 1
         min_chunk_size = config.min_chunk_size if config is not None else 1
         logging.info(f"[INFO] Skipping chunks smaller than {min_chunk_size} characters")
         logging.info(f"[INFO] Skipping chunks smaller than {min_chunk_size} characters")
         data_result = loader.load_data(src)
         data_result = loader.load_data(src)
@@ -49,8 +49,8 @@ class BaseChunker(JSONSerializable):
             for chunk in chunks:
             for chunk in chunks:
                 chunk_id = hashlib.sha256((chunk + url).encode()).hexdigest()
                 chunk_id = hashlib.sha256((chunk + url).encode()).hexdigest()
                 chunk_id = f"{app_id}--{chunk_id}" if app_id is not None else chunk_id
                 chunk_id = f"{app_id}--{chunk_id}" if app_id is not None else chunk_id
-                if idMap.get(chunk_id) is None and len(chunk) >= min_chunk_size:
-                    idMap[chunk_id] = True
+                if id_map.get(chunk_id) is None and len(chunk) >= min_chunk_size:
+                    id_map[chunk_id] = True
                     chunk_ids.append(chunk_id)
                     chunk_ids.append(chunk_id)
                     documents.append(chunk)
                     documents.append(chunk)
                     metadatas.append(meta_data)
                     metadatas.append(meta_data)
@@ -77,5 +77,6 @@ class BaseChunker(JSONSerializable):
 
 
         # TODO: This should be done during initialization. This means it has to be done in the child classes.
         # TODO: This should be done during initialization. This means it has to be done in the child classes.
 
 
-    def get_word_count(self, documents):
+    @staticmethod
+    def get_word_count(documents) -> int:
         return sum([len(document.split(" ")) for document in documents])
         return sum([len(document.split(" ")) for document in documents])
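Most changes in this commit repeat the pattern above: a method that never touches `self` becomes a `@staticmethod`, and camelCase locals like `idMap` get snake_case names. A minimal, self-contained sketch of the pattern (simplified, not embedchain's full class):

```python
class BaseChunker:
    @staticmethod
    def get_word_count(documents) -> int:
        # No `self` access, so a static method documents the intent
        # and can be called on the class without an instance.
        return sum(len(document.split(" ")) for document in documents)

print(BaseChunker.get_word_count(["one two", "three"]))  # 3
```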

+ 1 - 1
embedchain/client.py

@@ -31,7 +31,7 @@ class Client:
                 )

     @classmethod
-    def setup_dir(self):
+    def setup_dir(cls):
         """
         Loads the user id from the config file if it exists, otherwise generates a new
         one and saves it to the config file.
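The `setup_dir` fix corrects a naming bug rather than a style nit: the method was already decorated with `@classmethod`, so its first parameter receives the class object, and calling it `self` was misleading. A hedged illustration (simplified stand-in, not the real `Client`):

```python
class Client:
    CONFIG_DIR = "~/.embedchain"  # hypothetical constant for illustration

    @classmethod
    def setup_dir(cls):
        # `cls` is the Client class itself; no instance is ever created.
        print(f"setting up {cls.CONFIG_DIR} via {cls.__name__}")

Client.setup_dir()  # prints: setting up ~/.embedchain via Client
```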

+ 3 - 2
embedchain/config/add_config.py

@@ -26,7 +26,7 @@ class ChunkerConfig(BaseConfig):
         if self.min_chunk_size >= self.chunk_size:
             raise ValueError(f"min_chunk_size {min_chunk_size} should be less than chunk_size {chunk_size}")
         if self.min_chunk_size < self.chunk_overlap:
-            logging.warn(
+            logging.warning(
                 f"min_chunk_size {min_chunk_size} should be greater than chunk_overlap {chunk_overlap}, otherwise it is redundant."  # noqa:E501
             )

@@ -35,7 +35,8 @@ class ChunkerConfig(BaseConfig):
         else:
             self.length_function = length_function if length_function else len

-    def load_func(self, dotpath: str):
+    @staticmethod
+    def load_func(dotpath: str):
         if "." not in dotpath:
             return getattr(builtins, dotpath)
         else:
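`logging.warn` has been a deprecated alias of `logging.warning` since Python 3.3 and is removed in 3.13, so this replacement (repeated in several loaders below) is more than cosmetic. A quick demonstration, assuming a pre-3.13 interpreter where the alias still exists:

```python
import logging
import warnings

logging.basicConfig(level=logging.WARNING)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    logging.warn("old alias")               # deprecated, gone in Python 3.13
    print(caught[0].category.__name__)      # DeprecationWarning

logging.warning("supported spelling")       # the stable API
```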

+ 6 - 3
embedchain/config/cache_config.py

@@ -10,12 +10,12 @@ class CacheSimilarityEvalConfig(BaseConfig):
     This is the evaluator to compare two embeddings according to their distance computed in embedding retrieval stage.
     In the retrieval stage, `search_result` is the distance used for approximate nearest neighbor search and have been
     put into `cache_dict`. `max_distance` is used to bound this distance to make it between [0-`max_distance`].
-    `positive` is used to indicate this distance is directly proportional to the similarity of two entites.
-    If `positive` is set `False`, `max_distance` will be used to substract this distance to get the final score.
+    `positive` is used to indicate this distance is directly proportional to the similarity of two entities.
+    If `positive` is set `False`, `max_distance` will be used to subtract this distance to get the final score.

     :param max_distance: the bound of maximum distance.
     :type max_distance: float
-    :param positive: if the larger distance indicates more similar of two entities, It is True. Otherwise it is False.
+    :param positive: if the larger distance indicates more similar of two entities, It is True. Otherwise, it is False.
     :type positive: bool
     """

@@ -29,6 +29,7 @@ class CacheSimilarityEvalConfig(BaseConfig):
         self.max_distance = max_distance
         self.positive = positive

+    @staticmethod
     def from_config(config: Optional[Dict[str, Any]]):
         if config is None:
             return CacheSimilarityEvalConfig()
@@ -63,6 +64,7 @@ class CacheInitConfig(BaseConfig):
         self.similarity_threshold = similarity_threshold
         self.auto_flush = auto_flush

+    @staticmethod
     def from_config(config: Optional[Dict[str, Any]]):
         if config is None:
             return CacheInitConfig()
@@ -83,6 +85,7 @@ class CacheConfig(BaseConfig):
         self.similarity_eval_config = similarity_eval_config
         self.init_config = init_config

+    @staticmethod
     def from_config(config: Optional[Dict[str, Any]]):
         if config is None:
             return CacheConfig()
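Marking these `from_config` factories as `@staticmethod` matches how they are written: they name the class explicitly and never use `self`. One hedged observation on the design choice: a `@classmethod` returning `cls(...)` would behave identically here but would also respect subclassing. Sketch with a simplified config class (field names abbreviated for illustration):

```python
from typing import Any, Dict, Optional

class CacheConfig:
    def __init__(self, similarity_threshold: float = 0.8):
        self.similarity_threshold = similarity_threshold

    @staticmethod
    def from_config(config: Optional[Dict[str, Any]]):
        # Static factory: the class is hard-coded, as in the commit.
        if config is None:
            return CacheConfig()
        return CacheConfig(similarity_threshold=config.get("similarity_threshold", 0.8))

    @classmethod
    def from_config_cls(cls, config: Optional[Dict[str, Any]]):
        # Alternative: `cls(...)` would return the subclass when called
        # as SubConfig.from_config_cls(...).
        return cls(**(config or {}))

print(CacheConfig.from_config({"similarity_threshold": 0.9}).similarity_threshold)  # 0.9
```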

+ 6 - 4
embedchain/config/llm/base.py

@@ -155,24 +155,26 @@ class BaseLlmConfig(BaseConfig):
         self.stream = stream
         self.where = where

-    def validate_prompt(self, prompt: Template) -> bool:
+    @staticmethod
+    def validate_prompt(prompt: Template) -> Optional[re.Match[str]]:
         """
         validate the prompt

         :param prompt: the prompt to validate
         :type prompt: Template
         :return: valid (true) or invalid (false)
-        :rtype: bool
+        :rtype: Optional[re.Match[str]]
         """
         return re.search(query_re, prompt.template) and re.search(context_re, prompt.template)

-    def _validate_prompt_history(self, prompt: Template) -> bool:
+    @staticmethod
+    def _validate_prompt_history(prompt: Template) -> Optional[re.Match[str]]:
         """
         validate the prompt with history

         :param prompt: the prompt to validate
         :type prompt: Template
         :return: valid (true) or invalid (false)
-        :rtype: bool
+        :rtype: Optional[re.Match[str]]
         """
         return re.search(history_re, prompt.template)
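The annotation change reflects what the code actually returns: `re.search` yields an `re.Match` or `None`, and chaining two searches with `and` propagates that, so `bool` was never the true return type. Callers that only test truthiness are unaffected, as this standalone sketch shows (with placeholder regexes standing in for embedchain's real `query_re`/`context_re`):

```python
import re
from string import Template
from typing import Optional

query_re = r"\$query"      # placeholder patterns for illustration
context_re = r"\$context"

def validate_prompt(prompt: Template) -> Optional[re.Match]:
    # `a and b` returns None as soon as either search misses,
    # otherwise the second Match object.
    return re.search(query_re, prompt.template) and re.search(context_re, prompt.template)

print(bool(validate_prompt(Template("Answer $query using $context"))))  # True
print(bool(validate_prompt(Template("no placeholders here"))))          # False
```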

+ 2 - 2
embedchain/config/vectordb/qdrant.py

@@ -7,8 +7,8 @@ from embedchain.helpers.json_serializable import register_deserializable
 @register_deserializable
 class QdrantDBConfig(BaseVectorDbConfig):
     """
-    Config to initialize an qdrant client.
-    :param url. qdrant url or list of nodes url to be used for connection
+    Config to initialize a qdrant client.
+    :param: url. qdrant url or list of nodes url to be used for connection
     """

     def __init__(

+ 1 - 1
embedchain/config/vectordb/zilliz.py

@@ -26,7 +26,7 @@ class ZillizDBConfig(BaseVectorDbConfig):
         :param uri: Cluster endpoint obtained from the Zilliz Console, defaults to None
         :type uri: Optional[str], optional
         :param token: API Key, if a Serverless Cluster, username:password, if a Dedicated Cluster, defaults to None
-        :type port: Optional[str], optional
+        :type token: Optional[str], optional
         """
         self.uri = uri or os.environ.get("ZILLIZ_CLOUD_URI")
         if not self.uri:

+ 2 - 1
embedchain/data_formatter/data_formatter.py

@@ -34,7 +34,8 @@ class DataFormatter(JSONSerializable):
         self.loader = self._get_loader(data_type=data_type, config=config.loader, loader=loader)
         self.chunker = self._get_chunker(data_type=data_type, config=config.chunker, chunker=chunker)

-    def _lazy_load(self, module_path: str):
+    @staticmethod
+    def _lazy_load(module_path: str):
         module_path, class_name = module_path.rsplit(".", 1)
         module = import_module(module_path)
         return getattr(module, class_name)

+ 7 - 9
embedchain/embedchain.py

@@ -7,9 +7,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 from dotenv import load_dotenv
 from langchain.docstore.document import Document

-from embedchain.cache import (adapt, get_gptcache_session,
-                              gptcache_data_convert,
-                              gptcache_update_cache_callback)
+from embedchain.cache import adapt, get_gptcache_session, gptcache_data_convert, gptcache_update_cache_callback
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
 from embedchain.config.base_app_config import BaseAppConfig
@@ -19,8 +17,7 @@ from embedchain.embedder.base import BaseEmbedder
 from embedchain.helpers.json_serializable import JSONSerializable
 from embedchain.llm.base import BaseLlm
 from embedchain.loaders.base_loader import BaseLoader
-from embedchain.models.data_type import (DataType, DirectDataType,
-                                         IndirectDataType, SpecialDataType)
+from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
 from embedchain.telemetry.posthog import AnonymousTelemetry
 from embedchain.utils.misc import detect_datatype, is_valid_json_string
 from embedchain.vectordb.base import BaseVectorDB
@@ -84,7 +81,7 @@ class EmbedChain(JSONSerializable):
         # Attributes that aren't subclass related.
         self.user_asks = []

-        self.chunker: ChunkerConfig = None
+        self.chunker: Optional[ChunkerConfig] = None
         # Send anonymous telemetry
         self._telemetry_props = {"class": self.__class__.__name__}
         self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
@@ -290,7 +287,7 @@ class EmbedChain(JSONSerializable):
             #   Or it's different, then it will be added as a new text.
             return None
         elif chunker.data_type.value in [item.value for item in IndirectDataType]:
-            # These types have a indirect source reference
+            # These types have an indirect source reference
             # As long as the reference is the same, they can be updated.
             where = {"url": src}
             if chunker.data_type == DataType.JSON and is_valid_json_string(src):
@@ -442,10 +439,11 @@ class EmbedChain(JSONSerializable):
         )
         count_new_chunks = self.db.count() - chunks_before_addition

-        print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
+        print(f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}")
         return list(documents), metadatas, ids, count_new_chunks

-    def _format_result(self, results):
+    @staticmethod
+    def _format_result(results):
         return [
             (Document(page_content=result[0], metadata=result[1] or {}), result[2])
             for result in zip(

+ 3 - 3
embedchain/embedder/base.py

@@ -15,8 +15,8 @@ class EmbeddingFunc(EmbeddingFunction):
     def __init__(self, embedding_fn: Callable[[list[str]], list[str]]):
         self.embedding_fn = embedding_fn

-    def __call__(self, input: Embeddable) -> Embeddings:
-        return self.embedding_fn(input)
+    def __call__(self, input_: Embeddable) -> Embeddings:
+        return self.embedding_fn(input_)


 class BaseEmbedder:
@@ -29,7 +29,7 @@ class BaseEmbedder:

     def __init__(self, config: Optional[BaseEmbedderConfig] = None):
         """
-        Intialize the embedder class.
+        Initialize the embedder class.

         :param config: embedder configuration option class, defaults to None
         :type config: Optional[BaseEmbedderConfig], optional
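Renaming the parameter `input` to `input_` (here and in `embedder/google.py` below, mirrored by `id` → `id_` and `filter` → `filter_` later in this diff) stops the code from shadowing a builtin inside its own scope; the trailing underscore is the convention PEP 8 recommends for exactly this collision. A small illustration of the problem being avoided:

```python
def embed_bad(input):
    # The parameter shadows the builtin; input() is unreachable here
    # without going through the builtins module.
    return [x.lower() for x in input]

def embed_good(input_):
    # PEP 8's trailing-underscore convention keeps the builtin usable.
    return [x.lower() for x in input_]

print(embed_good(["A", "B"]))  # ['a', 'b']
```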

+ 2 - 2
embedchain/embedder/google.py

@@ -13,11 +13,11 @@ class GoogleAIEmbeddingFunction(EmbeddingFunction):
         super().__init__()
         self.config = config or GoogleAIEmbedderConfig()

-    def __call__(self, input: str) -> Embeddings:
+    def __call__(self, input_: str) -> Embeddings:
         model = self.config.model
         title = self.config.title
         task_type = self.config.task_type
-        embeddings = genai.embed_content(model=model, content=input, task_type=task_type, title=title)
+        embeddings = genai.embed_content(model=model, content=input_, task_type=task_type, title=title)
         return embeddings["embedding"]

+ 1 - 1
embedchain/helpers/json_serializable.py

@@ -42,7 +42,7 @@ class JSONSerializable:
     A class to represent a JSON serializable object.

     This class provides methods to serialize and deserialize objects,
-    as well as save serialized objects to a file and load them back.
+    as well as to save serialized objects to a file and load them back.
     """

     _deserializable_classes = set()  # Contains classes that are whitelisted for deserialization.

+ 8 - 7
embedchain/llm/base.py

@@ -4,9 +4,7 @@ from typing import Any, Dict, Generator, List, Optional
 from langchain.schema import BaseMessage as LCBaseMessage

 from embedchain.config import BaseLlmConfig
-from embedchain.config.llm.base import (DEFAULT_PROMPT,
-                                        DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
-                                        DOCS_SITE_PROMPT_TEMPLATE)
+from embedchain.config.llm.base import DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE, DOCS_SITE_PROMPT_TEMPLATE
 from embedchain.helpers.json_serializable import JSONSerializable
 from embedchain.memory.base import ChatHistory
 from embedchain.memory.message import ChatMessage
@@ -76,7 +74,7 @@ class BaseLlm(JSONSerializable):
         :return: The prompt
         :rtype: str
         """
-        context_string = (" | ").join(contexts)
+        context_string = " | ".join(contexts)
         web_search_result = kwargs.get("web_search_result", "")
         if web_search_result:
             context_string = self._append_search_and_context(context_string, web_search_result)
@@ -110,7 +108,8 @@ class BaseLlm(JSONSerializable):
             prompt = self.config.prompt.substitute(context=context_string, query=input_query)
         return prompt

-    def _append_search_and_context(self, context: str, web_search_result: str) -> str:
+    @staticmethod
+    def _append_search_and_context(context: str, web_search_result: str) -> str:
         """Append web search context to existing context

         :param context: Existing context
@@ -134,7 +133,8 @@ class BaseLlm(JSONSerializable):
         """
         return self.get_llm_model_answer(prompt)

-    def access_search_and_get_results(self, input_query: str):
+    @staticmethod
+    def access_search_and_get_results(input_query: str):
         """
         Search the internet for additional context

@@ -153,7 +153,8 @@ class BaseLlm(JSONSerializable):
         logging.info(f"Access search to get answers for {input_query}")
         return search.run(input_query)

-    def _stream_response(self, answer: Any) -> Generator[Any, Any, None]:
+    @staticmethod
+    def _stream_response(answer: Any) -> Generator[Any, Any, None]:
         """Generator to be used as streaming response

         :param answer: Answer chunk from llm

+ 1 - 1
embedchain/llm/google.py

@@ -44,7 +44,7 @@ class GoogleLlm(BaseLlm):
             "temperature": self.config.temperature or 0.5,
             "temperature": self.config.temperature or 0.5,
         }
         }
 
 
-        if self.config.top_p >= 0.0 and self.config.top_p <= 1.0:
+        if 0.0 <= self.config.top_p <= 1.0:
             generation_config_params["top_p"] = self.config.top_p
             generation_config_params["top_p"] = self.config.top_p
         else:
         else:
             raise ValueError("`top_p` must be > 0.0 and < 1.0")
             raise ValueError("`top_p` must be > 0.0 and < 1.0")
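The chained comparison is the idiomatic Python spelling and evaluates `self.config.top_p` only once. Worth noting in passing (the commit does not change it) that the bounds stay inclusive here and exclusive in `huggingface.py` below, while both error messages claim strict bounds. A neutral demonstration of the equivalence:

```python
top_p = 0.7

# The two forms are logically equivalent; the chained one reads like math.
assert (top_p >= 0.0 and top_p <= 1.0) == (0.0 <= top_p <= 1.0)

for candidate in (0.0, 0.5, 1.0, 1.5):
    print(candidate, 0.0 <= candidate <= 1.0)  # last one prints False
```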

+ 1 - 1
embedchain/llm/huggingface.py

@@ -48,7 +48,7 @@ class HuggingFaceLlm(BaseLlm):
             "max_new_tokens": config.max_tokens,
             "max_new_tokens": config.max_tokens,
         }
         }
 
 
-        if config.top_p > 0.0 and config.top_p < 1.0:
+        if 0.0 < config.top_p < 1.0:
             model_kwargs["top_p"] = config.top_p
             model_kwargs["top_p"] = config.top_p
         else:
         else:
             raise ValueError("`top_p` must be > 0.0 and < 1.0")
             raise ValueError("`top_p` must be > 0.0 and < 1.0")

+ 2 - 1
embedchain/llm/ollama.py

@@ -20,7 +20,8 @@ class OllamaLlm(BaseLlm):
     def get_llm_model_answer(self, prompt):
         return self._get_answer(prompt=prompt, config=self.config)

-    def _get_answer(self, prompt: str, config: BaseLlmConfig) -> Union[str, Iterable]:
+    @staticmethod
+    def _get_answer(prompt: str, config: BaseLlmConfig) -> Union[str, Iterable]:
         callback_manager = [StreamingStdOutCallbackHandler()] if config.stream else [StdOutCallbackHandler()]

         llm = Ollama(

+ 1 - 1
embedchain/loaders/base_loader.py

@@ -5,7 +5,7 @@ class BaseLoader(JSONSerializable):
     def __init__(self):
         pass

-    def load_data():
+    def load_data(self, url):
         """
         Implemented by child classes
         """

+ 1 - 1
embedchain/loaders/directory_loader.py

@@ -32,7 +32,7 @@ class DirectoryLoader(BaseLoader):
         doc_id = hashlib.sha256((str(data_list) + str(directory_path)).encode()).hexdigest()

         for error in self.errors:
-            logging.warn(error)
+            logging.warning(error)

         return {"doc_id": doc_id, "data": data_list}

+ 2 - 1
embedchain/loaders/docs_site_loader.py

@@ -49,7 +49,8 @@ class DocsSiteLoader(BaseLoader):
         urls = [link for link in self.visited_links if urlparse(link).netloc == urlparse(url).netloc]
         return urls

-    def _load_data_from_url(self, url):
+    @staticmethod
+    def _load_data_from_url(url: str) -> list:
         response = requests.get(url)
         if response.status_code != 200:
             logging.info(f"Failed to fetch the website: {response.status_code}")

+ 19 - 15
embedchain/loaders/github.py

@@ -18,7 +18,7 @@ VALID_SEARCH_TYPES = set(["code", "repo", "pr", "issue", "discussion"])


 class GithubLoader(BaseLoader):
-    """Load data from github search query."""
+    """Load data from GitHub search query."""

     def __init__(self, config: Optional[Dict[str, Any]] = None):
         super().__init__()
@@ -48,7 +48,7 @@ class GithubLoader(BaseLoader):
             self.client = None

     def _github_search_code(self, query: str):
-        """Search github code."""
+        """Search GitHub code."""
         data = []
         results = self.client.search_code(query)
         for result in tqdm(results, total=results.totalCount, desc="Loading code files from github"):
@@ -66,7 +66,8 @@ class GithubLoader(BaseLoader):
             )
         return data

-    def _get_github_repo_data(self, repo_url: str):
+    @staticmethod
+    def _get_github_repo_data(repo_url: str):
         local_hash = hashlib.sha256(repo_url.encode()).hexdigest()
         local_path = f"/tmp/{local_hash}"
         data = []
@@ -121,14 +122,14 @@ class GithubLoader(BaseLoader):

         return data

-    def _github_search_repo(self, query: str):
-        """Search github repo."""
+    def _github_search_repo(self, query: str) -> list[dict]:
+        """Search GitHub repo."""
         data = []
         logging.info(f"Searching github repos with query: {query}")
         results = self.client.search_repositories(query)
         # Add repo urls and descriptions
         urls = list(map(lambda x: x.html_url, results))
-        discriptions = list(map(lambda x: x.description, results))
+        descriptions = list(map(lambda x: x.description, results))
         data.append(
             {
                 "content": clean_string(desc),
@@ -136,7 +137,7 @@ class GithubLoader(BaseLoader):
                     "url": url,
                 },
             }
-            for url, desc in zip(urls, discriptions)
+            for url, desc in zip(urls, descriptions)
         )

         # Add repo contents
@@ -146,8 +147,8 @@ class GithubLoader(BaseLoader):
             data = self._get_github_repo_data(clone_url)
         return data

-    def _github_search_issues_and_pr(self, query: str, type: str):
-        """Search github issues and PRs."""
+    def _github_search_issues_and_pr(self, query: str, type: str) -> list[dict]:
+        """Search GitHub issues and PRs."""
         data = []

         query = f"{query} is:{type}"
@@ -161,7 +162,7 @@ class GithubLoader(BaseLoader):
             title = result.title
             body = result.body
             if not body:
-                logging.warn(f"Skipping issue because empty content for: {url}")
+                logging.warning(f"Skipping issue because empty content for: {url}")
                 continue
             labels = " ".join([label.name for label in result.labels])
             issue_comments = result.get_comments()
@@ -186,7 +187,7 @@ class GithubLoader(BaseLoader):

     # need to test more for discussion
     def _github_search_discussions(self, query: str):
-        """Search github discussions."""
+        """Search GitHub discussions."""
         data = []

         query = f"{query} is:discussion"
@@ -202,7 +203,7 @@ class GithubLoader(BaseLoader):
                     title = discussion.title
                     body = discussion.body
                     if not body:
-                        logging.warn(f"Skipping discussion because empty content for: {url}")
+                        logging.warning(f"Skipping discussion because empty content for: {url}")
                         continue
                     comments = []
                     comments_created_at = []
@@ -233,11 +234,14 @@ class GithubLoader(BaseLoader):
             data = self._github_search_issues_and_pr(query, search_type)
         elif search_type == "discussion":
             raise ValueError("GithubLoader does not support searching discussions yet.")
+        else:
+            raise NotImplementedError(f"{search_type} not supported")

         return data

-    def _get_valid_github_query(self, query: str):
-        """Check if query is valid and return search types and valid github query."""
+    @staticmethod
+    def _get_valid_github_query(query: str):
+        """Check if query is valid and return search types and valid GitHub query."""
         query_terms = shlex.split(query)
         # query must provide repo to load data from
         if len(query_terms) < 1 or "repo:" not in query:
@@ -273,7 +277,7 @@ class GithubLoader(BaseLoader):
         return types, query

     def load_data(self, search_query: str, max_results: int = 1000):
-        """Load data from github search query."""
+        """Load data from GitHub search query."""

         if not self.client:
             raise ValueError(
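The new `else` branch closes a silent failure mode: before the fix, an unrecognized `search_type` fell through the `if`/`elif` chain and the function returned whatever `data` happened to hold on that path. A hedged sketch of the dispatch pattern, with simplified branch bodies standing in for the loader's real search calls:

```python
VALID_SEARCH_TYPES = {"code", "repo", "pr", "issue", "discussion"}  # from the module

def search(search_type: str, query: str):
    if search_type == "code":
        return f"code results for {query}"
    elif search_type in ("pr", "issue"):
        return f"{search_type} results for {query}"
    elif search_type == "discussion":
        raise ValueError("discussions are not supported yet")
    else:
        # Exhaustive dispatch: fail loudly instead of returning nothing.
        raise NotImplementedError(f"{search_type} not supported")

print(search("code", "repo:foo/bar"))
try:
    search("wiki", "repo:foo/bar")
except NotImplementedError as e:
    print("rejected:", e)
```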

+ 2 - 1
embedchain/loaders/image.py

@@ -20,7 +20,8 @@ class ImageLoader(BaseLoader):
         self.api_key = api_key or os.environ["OPENAI_API_KEY"]
         self.client = OpenAI(api_key=self.api_key)

-    def _encode_image(self, image_path: str):
+    @staticmethod
+    def _encode_image(image_path: str):
         with open(image_path, "rb") as image_file:
             return base64.b64encode(image_file.read()).decode("utf-8")

+ 2 - 1
embedchain/loaders/json.py

@@ -15,7 +15,8 @@ class JSONReader:
         """Initialize the JSONReader."""
         """Initialize the JSONReader."""
         pass
         pass
 
 
-    def load_data(self, json_data: Union[Dict, str]) -> List[str]:
+    @staticmethod
+    def load_data(json_data: Union[Dict, str]) -> List[str]:
         """Load data from a JSON structure.
         """Load data from a JSON structure.
 
 
         Args:
         Args:

+ 2 - 1
embedchain/loaders/mysql.py

@@ -39,7 +39,8 @@ class MySQLLoader(BaseLoader):
                     Refer `https://docs.embedchain.ai/data-sources/mysql`.",
             )

-    def _check_query(self, query):
+    @staticmethod
+    def _check_query(query):
         if not isinstance(query, str):
             raise ValueError(
                 f"Invalid mysql query: {query}",

+ 2 - 2
embedchain/loaders/postgres.py

@@ -24,7 +24,6 @@ class PostgresLoader(BaseLoader):
                     Run `pip install --upgrade 'embedchain[postgres]'`"
             ) from e

-        config_info = ""
         if "url" in config:
             config_info = config.get("url")
         else:
@@ -37,7 +36,8 @@ class PostgresLoader(BaseLoader):
         self.connection = psycopg.connect(conninfo=config_info)
         self.cursor = self.connection.cursor()

-    def _check_query(self, query):
+    @staticmethod
+    def _check_query(query):
         if not isinstance(query, str):
             raise ValueError(
                 f"Invalid postgres query: {query}. Provide the valid source to add from postgres, make sure you are following `https://docs.embedchain.ai/data-sources/postgres`",  # noqa:E501

+ 2 - 1
embedchain/loaders/slack.py

@@ -56,7 +56,8 @@ class SlackLoader(BaseLoader):
         )
         logging.info("Slack Loader setup successful!")

-    def _check_query(self, query):
+    @staticmethod
+    def _check_query(query):
         if not isinstance(query, str):
             raise ValueError(
                 f"Invalid query passed to Slack loader, found: {query}. Check `https://docs.embedchain.ai/data-sources/slack` to learn more."  # noqa:E501

+ 1 - 1
embedchain/loaders/unstructured_file.py

@@ -8,7 +8,7 @@ from embedchain.utils.misc import clean_string
 @register_deserializable
 class UnstructuredLoader(BaseLoader):
     def load_data(self, url):
-        """Load data from a Unstructured file."""
+        """Load data from an Unstructured file."""
         try:
             from langchain.document_loaders import UnstructuredFileLoader
         except ImportError:

+ 5 - 4
embedchain/loaders/web_page.py

@@ -21,7 +21,7 @@ class WebPageLoader(BaseLoader):
     _session = requests.Session()

     def load_data(self, url):
-        """Load data from a web page using a shared requests session."""
+        """Load data from a web page using a shared requests' session."""
         response = self._session.get(url, timeout=30)
         response.raise_for_status()
         data = response.content
@@ -40,7 +40,8 @@ class WebPageLoader(BaseLoader):
             ],
         }

-    def _get_clean_content(self, html, url) -> str:
+    @staticmethod
+    def _get_clean_content(html, url) -> str:
         soup = BeautifulSoup(html, "html.parser")
         original_size = len(str(soup.get_text()))

@@ -60,8 +61,8 @@ class WebPageLoader(BaseLoader):
             tag.decompose()

         ids_to_exclude = ["sidebar", "main-navigation", "menu-main-menu"]
-        for id in ids_to_exclude:
-            tags = soup.find_all(id=id)
+        for id_ in ids_to_exclude:
+            tags = soup.find_all(id=id_)
             for tag in tags:
                 tag.decompose()

+ 4 - 2
embedchain/memory/base.py

@@ -113,10 +113,12 @@ class ChatHistory:
         count = self.cursor.fetchone()[0]
         return count

-    def _serialize_json(self, metadata: Dict[str, Any]):
+    @staticmethod
+    def _serialize_json(metadata: Dict[str, Any]):
         return json.dumps(metadata)

-    def _deserialize_json(self, metadata: str):
+    @staticmethod
+    def _deserialize_json(metadata: str):
         return json.loads(metadata)

     def close_connection(self):

+ 2 - 2
embedchain/memory/message.py

@@ -54,7 +54,7 @@ class ChatMessage(JSONSerializable):
         if self.human_message:
             logging.info(
                 "Human message already exists in the chat message,\
-                overwritting it with new message."
+                overwriting it with new message."
             )

         self.human_message = BaseMessage(content=message, created_by="human", metadata=metadata)
@@ -63,7 +63,7 @@ class ChatMessage(JSONSerializable):
         if self.ai_message:
             logging.info(
                 "AI message already exists in the chat message,\
-                overwritting it with new message."
+                overwriting it with new message."
             )

         self.ai_message = BaseMessage(content=message, created_by="ai", metadata=metadata)

+ 1 - 1
embedchain/memory/utils.py

@@ -7,7 +7,7 @@ def merge_metadata_dict(left: Optional[Dict[str, Any]], right: Optional[Dict[str

     Args:
         left (Dict[str, Any]): metadata of human message
-        right (Dict[str, Any]): metadata of ai message
+        right (Dict[str, Any]): metadata of AI message

     Returns:
         Dict[str, Any]: combined metadata dict with dedup

+ 5 - 3
embedchain/store/assistants.py

@@ -19,7 +19,7 @@ from embedchain.utils.misc import detect_datatype

 logging.basicConfig(level=logging.WARN)

-# Setup the user directory if doesn't exist already
+# Set up the user directory if it doesn't exist already
 Client.setup_dir()


@@ -130,12 +130,14 @@ class OpenAIAssistant:
         messages = self._client.beta.threads.messages.list(thread_id=self.thread_id, order="desc")
         return list(messages)

-    def _format_message(self, thread_message):
+    @staticmethod
+    def _format_message(thread_message):
         thread_message = cast(ThreadMessage, thread_message)
         content = [c.text.value for c in thread_message.content if isinstance(c, MessageContentText)]
         return " ".join(content)

-    def _save_temp_data(self, data, source):
+    @staticmethod
+    def _save_temp_data(data, source):
         special_chars_pattern = r'[\\/:*?"<>|&=% ]+'
         sanitized_source = re.sub(special_chars_pattern, "_", source)[:256]
         temp_dir = tempfile.mkdtemp()

+ 2 - 1
embedchain/telemetry/posthog.py

@@ -38,7 +38,8 @@ class AnonymousTelemetry:
         posthog_logger = logging.getLogger("posthog")
         posthog_logger.disabled = True

-    def _get_user_id(self):
+    @staticmethod
+    def _get_user_id():
         os.makedirs(CONFIG_DIR, exist_ok=True)
         if os.path.exists(CONFIG_FILE):
             with open(CONFIG_FILE, "r") as f:

+ 2 - 3
embedchain/utils/misc.py

@@ -201,8 +201,7 @@ def detect_datatype(source: Any) -> DataType:
     formatted_source = format_source(str(source), 30)

     if url:
-        from langchain.document_loaders.youtube import \
-            ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
+        from langchain.document_loaders.youtube import ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS

         if url.netloc in YOUTUBE_ALLOWED_NETLOCS:
             logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")
@@ -345,7 +344,7 @@ def detect_datatype(source: Any) -> DataType:
             return DataType.TEXT_FILE

         # If the source is a valid file, that's not detectable as a type, an error is raised.
-        # It does not fallback to text.
+        # It does not fall back to text.
         raise ValueError(
             "Source points to a valid file, but based on the filename, no `data_type` can be detected. Please be aware, that not all data_types allow conventional file references, some require the use of the `file URI scheme`. Please refer to the embedchain documentation (https://docs.embedchain.ai/advanced/data_types#remote-data-types)."  # noqa: E501
         )

+ 1 - 1
embedchain/vectordb/base.py

@@ -49,7 +49,7 @@ class BaseVectorDB(JSONSerializable):
         raise NotImplementedError

     def query(self):
-        """Query contents from vector data base based on vector similarity"""
+        """Query contents from vector database based on vector similarity"""
         raise NotImplementedError

     def count(self) -> int:

+ 4 - 2
embedchain/vectordb/chroma.py

@@ -75,7 +75,8 @@ class ChromaDB(BaseVectorDB):
         """Called during initialization"""
         """Called during initialization"""
         return self.client
         return self.client
 
 
-    def _generate_where_clause(self, where: Dict[str, any]) -> str:
+    @staticmethod
+    def _generate_where_clause(where: Dict[str, any]) -> Dict[str, any]:
         # If only one filter is supplied, return it as is
         # If only one filter is supplied, return it as is
         # (no need to wrap in $and based on chroma docs)
         # (no need to wrap in $and based on chroma docs)
         if len(where.keys()) <= 1:
         if len(where.keys()) <= 1:
@@ -160,7 +161,8 @@ class ChromaDB(BaseVectorDB):
                 ids=ids[i : i + self.BATCH_SIZE],
                 ids=ids[i : i + self.BATCH_SIZE],
             )
             )
 
 
-    def _format_result(self, results: QueryResult) -> list[tuple[Document, float]]:
+    @staticmethod
+    def _format_result(results: QueryResult) -> list[tuple[Document, float]]:
         """
         """
         Format Chroma results
         Format Chroma results
 
 

+ 2 - 2
embedchain/vectordb/elasticsearch.py

@@ -88,7 +88,7 @@ class ElasticsearchDB(BaseVectorDB):
         """
         """
         Get existing doc ids present in vector database
         Get existing doc ids present in vector database
 
 
-        :param ids: _list of doc ids to check for existance
+        :param ids: _list of doc ids to check for existence
         :type ids: List[str]
         :type ids: List[str]
         :param where: to filter data
         :param where: to filter data
         :type where: Dict[str, any]
         :type where: Dict[str, any]
@@ -161,7 +161,7 @@ class ElasticsearchDB(BaseVectorDB):
         **kwargs: Optional[Dict[str, Any]],
         **kwargs: Optional[Dict[str, Any]],
     ) -> Union[List[Tuple[str, Dict]], List[str]]:
     ) -> Union[List[Tuple[str, Dict]], List[str]]:
         """
         """
-        query contents from vector data base based on vector similarity
+        query contents from vector database based on vector similarity
 
 
         :param input_query: list of query string
         :param input_query: list of query string
         :type input_query: List[str]
         :type input_query: List[str]

+ 1 - 1
embedchain/vectordb/opensearch.py

@@ -163,7 +163,7 @@ class OpenSearchDB(BaseVectorDB):
         **kwargs: Optional[Dict[str, Any]],
     ) -> Union[List[Tuple[str, Dict]], List[str]]:
         """
-        query contents from vector data base based on vector similarity
+        query contents from vector database based on vector similarity

         :param input_query: list of query string
         :type input_query: List[str]

+ 2 - 1
embedchain/vectordb/weaviate.py

@@ -305,7 +305,8 @@ class WeaviateDB(BaseVectorDB):
         """
         """
         return f"{self.config.collection_name}_{self.embedder.vector_dimension}".capitalize()
         return f"{self.config.collection_name}_{self.embedder.vector_dimension}".capitalize()
 
 
-    def _query_with_cursor(self, query, cursor):
+    @staticmethod
+    def _query_with_cursor(query, cursor):
         if cursor is not None:
         if cursor is not None:
             query.with_after(cursor)
             query.with_after(cursor)
         results = query.do()
         results = query.do()

+ 5 - 6
embedchain/vectordb/zilliz.py

@@ -6,8 +6,7 @@ from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.vectordb.base import BaseVectorDB

 try:
-    from pymilvus import (Collection, CollectionSchema, DataType, FieldSchema,
-                          MilvusClient, connections, utility)
+    from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, MilvusClient, connections, utility
 except ImportError:
     raise ImportError(
         "Zilliz requires extra dependencies. Install with `pip install --upgrade embedchain[milvus]`"
@@ -97,10 +96,10 @@ class ZillizVectorDB(BaseVectorDB):
         if ids is None or len(ids) == 0 or self.collection.num_entities == 0:
             return {"ids": []}

-        if not (self.collection.is_empty):
-            filter = f"id in {ids}"
+        if not self.collection.is_empty:
+            filter_ = f"id in {ids}"
             results = self.client.query(
-                collection_name=self.config.collection_name, filter=filter, output_fields=["id"]
+                collection_name=self.config.collection_name, filter=filter_, output_fields=["id"]
             )
             results = [res["id"] for res in results]

@@ -134,7 +133,7 @@ class ZillizVectorDB(BaseVectorDB):
         **kwargs: Optional[Dict[str, Any]],
     ) -> Union[List[Tuple[str, Dict]], List[str]]:
         """
-        Query contents from vector data base based on vector similarity
+        Query contents from vector database based on vector similarity

         :param input_query: list of query string
         :type input_query: List[str]

+ 2 - 1
tests/chunkers/test_text.py

@@ -69,7 +69,8 @@ class TestTextChunker:


 class MockLoader:
-    def load_data(self, src):
+    @staticmethod
+    def load_data(src) -> dict:
         """
         Mock loader that returns a list of data dictionaries.
         Adjust this method to return different data for testing.