瀏覽代碼

[Bug fix] Fix typos, static methods and other sanity improvements in the package (#1129)

Sandra Serrano 1 年之前
父節點
當前提交
2496ed133e
共有 41 個文件被更改,包括 133 次插入103 次删除
  1. 9 4
      embedchain/app.py
  2. 6 5
      embedchain/chunkers/base_chunker.py
  3. 1 1
      embedchain/client.py
  4. 3 2
      embedchain/config/add_config.py
  5. 6 3
      embedchain/config/cache_config.py
  6. 6 4
      embedchain/config/llm/base.py
  7. 2 2
      embedchain/config/vectordb/qdrant.py
  8. 1 1
      embedchain/config/vectordb/zilliz.py
  9. 2 1
      embedchain/data_formatter/data_formatter.py
  10. 7 9
      embedchain/embedchain.py
  11. 3 3
      embedchain/embedder/base.py
  12. 2 2
      embedchain/embedder/google.py
  13. 1 1
      embedchain/helpers/json_serializable.py
  14. 8 7
      embedchain/llm/base.py
  15. 1 1
      embedchain/llm/google.py
  16. 1 1
      embedchain/llm/huggingface.py
  17. 2 1
      embedchain/llm/ollama.py
  18. 1 1
      embedchain/loaders/base_loader.py
  19. 1 1
      embedchain/loaders/directory_loader.py
  20. 2 1
      embedchain/loaders/docs_site_loader.py
  21. 19 15
      embedchain/loaders/github.py
  22. 2 1
      embedchain/loaders/image.py
  23. 2 1
      embedchain/loaders/json.py
  24. 2 1
      embedchain/loaders/mysql.py
  25. 2 2
      embedchain/loaders/postgres.py
  26. 2 1
      embedchain/loaders/slack.py
  27. 1 1
      embedchain/loaders/unstructured_file.py
  28. 5 4
      embedchain/loaders/web_page.py
  29. 4 2
      embedchain/memory/base.py
  30. 2 2
      embedchain/memory/message.py
  31. 1 1
      embedchain/memory/utils.py
  32. 5 3
      embedchain/store/assistants.py
  33. 2 1
      embedchain/telemetry/posthog.py
  34. 2 3
      embedchain/utils/misc.py
  35. 1 1
      embedchain/vectordb/base.py
  36. 4 2
      embedchain/vectordb/chroma.py
  37. 2 2
      embedchain/vectordb/elasticsearch.py
  38. 1 1
      embedchain/vectordb/opensearch.py
  39. 2 1
      embedchain/vectordb/weaviate.py
  40. 5 6
      embedchain/vectordb/zilliz.py
  41. 2 1
      tests/chunkers/test_text.py

+ 9 - 4
embedchain/app.py

@@ -9,9 +9,14 @@ from typing import Any, Dict, Optional
 import requests
 import yaml
 
-from embedchain.cache import (Config, ExactMatchEvaluation,
-                              SearchDistanceEvaluation, cache,
-                              gptcache_data_manager, gptcache_pre_function)
+from embedchain.cache import (
+    Config,
+    ExactMatchEvaluation,
+    SearchDistanceEvaluation,
+    cache,
+    gptcache_data_manager,
+    gptcache_pre_function,
+)
 from embedchain.client import Client
 from embedchain.config import AppConfig, CacheConfig, ChunkerConfig
 from embedchain.constants import SQLITE_PATH
@@ -27,7 +32,7 @@ from embedchain.utils.misc import validate_config
 from embedchain.vectordb.base import BaseVectorDB
 from embedchain.vectordb.chroma import ChromaDB
 
-# Setup the user directory if doesn't exist already
+# Set up the user directory if it doesn't exist already
 Client.setup_dir()
 
 

+ 6 - 5
embedchain/chunkers/base_chunker.py

@@ -17,7 +17,7 @@ class BaseChunker(JSONSerializable):
         """
         Loads data and chunks it.
 
-        :param loader: The loader which's `load_data` method is used to create
+        :param loader: The loader whose `load_data` method is used to create
         the raw data.
         :param src: The data to be handled by the loader. Can be a URL for
         remote sources or local content for local loaders.
@@ -25,7 +25,7 @@ class BaseChunker(JSONSerializable):
         """
         documents = []
         chunk_ids = []
-        idMap = {}
+        id_map = {}
         min_chunk_size = config.min_chunk_size if config is not None else 1
         logging.info(f"[INFO] Skipping chunks smaller than {min_chunk_size} characters")
         data_result = loader.load_data(src)
@@ -49,8 +49,8 @@ class BaseChunker(JSONSerializable):
             for chunk in chunks:
                 chunk_id = hashlib.sha256((chunk + url).encode()).hexdigest()
                 chunk_id = f"{app_id}--{chunk_id}" if app_id is not None else chunk_id
-                if idMap.get(chunk_id) is None and len(chunk) >= min_chunk_size:
-                    idMap[chunk_id] = True
+                if id_map.get(chunk_id) is None and len(chunk) >= min_chunk_size:
+                    id_map[chunk_id] = True
                     chunk_ids.append(chunk_id)
                     documents.append(chunk)
                     metadatas.append(meta_data)
@@ -77,5 +77,6 @@ class BaseChunker(JSONSerializable):
 
         # TODO: This should be done during initialization. This means it has to be done in the child classes.
 
-    def get_word_count(self, documents):
+    @staticmethod
+    def get_word_count(documents) -> int:
         return sum([len(document.split(" ")) for document in documents])

+ 1 - 1
embedchain/client.py

@@ -31,7 +31,7 @@ class Client:
                 )
 
     @classmethod
-    def setup_dir(self):
+    def setup_dir(cls):
         """
         Loads the user id from the config file if it exists, otherwise generates a new
         one and saves it to the config file.

+ 3 - 2
embedchain/config/add_config.py

@@ -26,7 +26,7 @@ class ChunkerConfig(BaseConfig):
         if self.min_chunk_size >= self.chunk_size:
             raise ValueError(f"min_chunk_size {min_chunk_size} should be less than chunk_size {chunk_size}")
         if self.min_chunk_size < self.chunk_overlap:
-            logging.warn(
+            logging.warning(
                 f"min_chunk_size {min_chunk_size} should be greater than chunk_overlap {chunk_overlap}, otherwise it is redundant."  # noqa:E501
             )
 
@@ -35,7 +35,8 @@ class ChunkerConfig(BaseConfig):
         else:
             self.length_function = length_function if length_function else len
 
-    def load_func(self, dotpath: str):
+    @staticmethod
+    def load_func(dotpath: str):
         if "." not in dotpath:
             return getattr(builtins, dotpath)
         else:

+ 6 - 3
embedchain/config/cache_config.py

@@ -10,12 +10,12 @@ class CacheSimilarityEvalConfig(BaseConfig):
     This is the evaluator to compare two embeddings according to their distance computed in embedding retrieval stage.
     In the retrieval stage, `search_result` is the distance used for approximate nearest neighbor search and have been
     put into `cache_dict`. `max_distance` is used to bound this distance to make it between [0-`max_distance`].
-    `positive` is used to indicate this distance is directly proportional to the similarity of two entites.
-    If `positive` is set `False`, `max_distance` will be used to substract this distance to get the final score.
+    `positive` is used to indicate this distance is directly proportional to the similarity of two entities.
+    If `positive` is set `False`, `max_distance` will be used to subtract this distance to get the final score.
 
     :param max_distance: the bound of maximum distance.
     :type max_distance: float
-    :param positive: if the larger distance indicates more similar of two entities, It is True. Otherwise it is False.
+    :param positive: if the larger distance indicates more similar of two entities, It is True. Otherwise, it is False.
     :type positive: bool
     """
 
@@ -29,6 +29,7 @@ class CacheSimilarityEvalConfig(BaseConfig):
         self.max_distance = max_distance
         self.positive = positive
 
+    @staticmethod
     def from_config(config: Optional[Dict[str, Any]]):
         if config is None:
             return CacheSimilarityEvalConfig()
@@ -63,6 +64,7 @@ class CacheInitConfig(BaseConfig):
         self.similarity_threshold = similarity_threshold
         self.auto_flush = auto_flush
 
+    @staticmethod
     def from_config(config: Optional[Dict[str, Any]]):
         if config is None:
             return CacheInitConfig()
@@ -83,6 +85,7 @@ class CacheConfig(BaseConfig):
         self.similarity_eval_config = similarity_eval_config
         self.init_config = init_config
 
+    @staticmethod
     def from_config(config: Optional[Dict[str, Any]]):
         if config is None:
             return CacheConfig()

+ 6 - 4
embedchain/config/llm/base.py

@@ -155,24 +155,26 @@ class BaseLlmConfig(BaseConfig):
         self.stream = stream
         self.where = where
 
-    def validate_prompt(self, prompt: Template) -> bool:
+    @staticmethod
+    def validate_prompt(prompt: Template) -> Optional[re.Match[str]]:
         """
         validate the prompt
 
         :param prompt: the prompt to validate
         :type prompt: Template
         :return: valid (true) or invalid (false)
-        :rtype: bool
+        :rtype: Optional[re.Match[str]]
         """
         return re.search(query_re, prompt.template) and re.search(context_re, prompt.template)
 
-    def _validate_prompt_history(self, prompt: Template) -> bool:
+    @staticmethod
+    def _validate_prompt_history(prompt: Template) -> Optional[re.Match[str]]:
         """
         validate the prompt with history
 
         :param prompt: the prompt to validate
         :type prompt: Template
         :return: valid (true) or invalid (false)
-        :rtype: bool
+        :rtype: Optional[re.Match[str]]
         """
         return re.search(history_re, prompt.template)

+ 2 - 2
embedchain/config/vectordb/qdrant.py

@@ -7,8 +7,8 @@ from embedchain.helpers.json_serializable import register_deserializable
 @register_deserializable
 class QdrantDBConfig(BaseVectorDbConfig):
     """
-    Config to initialize an qdrant client.
-    :param url. qdrant url or list of nodes url to be used for connection
+    Config to initialize a qdrant client.
+    :param: url. qdrant url or list of nodes url to be used for connection
     """
 
     def __init__(

+ 1 - 1
embedchain/config/vectordb/zilliz.py

@@ -26,7 +26,7 @@ class ZillizDBConfig(BaseVectorDbConfig):
         :param uri: Cluster endpoint obtained from the Zilliz Console, defaults to None
         :type uri: Optional[str], optional
         :param token: API Key, if a Serverless Cluster, username:password, if a Dedicated Cluster, defaults to None
-        :type port: Optional[str], optional
+        :type token: Optional[str], optional
         """
         self.uri = uri or os.environ.get("ZILLIZ_CLOUD_URI")
         if not self.uri:

+ 2 - 1
embedchain/data_formatter/data_formatter.py

@@ -34,7 +34,8 @@ class DataFormatter(JSONSerializable):
         self.loader = self._get_loader(data_type=data_type, config=config.loader, loader=loader)
         self.chunker = self._get_chunker(data_type=data_type, config=config.chunker, chunker=chunker)
 
-    def _lazy_load(self, module_path: str):
+    @staticmethod
+    def _lazy_load(module_path: str):
         module_path, class_name = module_path.rsplit(".", 1)
         module = import_module(module_path)
         return getattr(module, class_name)

+ 7 - 9
embedchain/embedchain.py

@@ -7,9 +7,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 from dotenv import load_dotenv
 from langchain.docstore.document import Document
 
-from embedchain.cache import (adapt, get_gptcache_session,
-                              gptcache_data_convert,
-                              gptcache_update_cache_callback)
+from embedchain.cache import adapt, get_gptcache_session, gptcache_data_convert, gptcache_update_cache_callback
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
 from embedchain.config.base_app_config import BaseAppConfig
@@ -19,8 +17,7 @@ from embedchain.embedder.base import BaseEmbedder
 from embedchain.helpers.json_serializable import JSONSerializable
 from embedchain.llm.base import BaseLlm
 from embedchain.loaders.base_loader import BaseLoader
-from embedchain.models.data_type import (DataType, DirectDataType,
-                                         IndirectDataType, SpecialDataType)
+from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
 from embedchain.telemetry.posthog import AnonymousTelemetry
 from embedchain.utils.misc import detect_datatype, is_valid_json_string
 from embedchain.vectordb.base import BaseVectorDB
@@ -84,7 +81,7 @@ class EmbedChain(JSONSerializable):
         # Attributes that aren't subclass related.
         self.user_asks = []
 
-        self.chunker: ChunkerConfig = None
+        self.chunker: Optional[ChunkerConfig] = None
         # Send anonymous telemetry
         self._telemetry_props = {"class": self.__class__.__name__}
         self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
@@ -290,7 +287,7 @@ class EmbedChain(JSONSerializable):
             #   Or it's different, then it will be added as a new text.
             return None
         elif chunker.data_type.value in [item.value for item in IndirectDataType]:
-            # These types have a indirect source reference
+            # These types have an indirect source reference
             # As long as the reference is the same, they can be updated.
             where = {"url": src}
             if chunker.data_type == DataType.JSON and is_valid_json_string(src):
@@ -442,10 +439,11 @@ class EmbedChain(JSONSerializable):
         )
         count_new_chunks = self.db.count() - chunks_before_addition
 
-        print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
+        print(f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}")
         return list(documents), metadatas, ids, count_new_chunks
 
-    def _format_result(self, results):
+    @staticmethod
+    def _format_result(results):
         return [
             (Document(page_content=result[0], metadata=result[1] or {}), result[2])
             for result in zip(

+ 3 - 3
embedchain/embedder/base.py

@@ -15,8 +15,8 @@ class EmbeddingFunc(EmbeddingFunction):
     def __init__(self, embedding_fn: Callable[[list[str]], list[str]]):
         self.embedding_fn = embedding_fn
 
-    def __call__(self, input: Embeddable) -> Embeddings:
-        return self.embedding_fn(input)
+    def __call__(self, input_: Embeddable) -> Embeddings:
+        return self.embedding_fn(input_)
 
 
 class BaseEmbedder:
@@ -29,7 +29,7 @@ class BaseEmbedder:
 
     def __init__(self, config: Optional[BaseEmbedderConfig] = None):
         """
-        Intialize the embedder class.
+        Initialize the embedder class.
 
         :param config: embedder configuration option class, defaults to None
         :type config: Optional[BaseEmbedderConfig], optional

+ 2 - 2
embedchain/embedder/google.py

@@ -13,11 +13,11 @@ class GoogleAIEmbeddingFunction(EmbeddingFunction):
         super().__init__()
         self.config = config or GoogleAIEmbedderConfig()
 
-    def __call__(self, input: str) -> Embeddings:
+    def __call__(self, input_: str) -> Embeddings:
         model = self.config.model
         title = self.config.title
         task_type = self.config.task_type
-        embeddings = genai.embed_content(model=model, content=input, task_type=task_type, title=title)
+        embeddings = genai.embed_content(model=model, content=input_, task_type=task_type, title=title)
         return embeddings["embedding"]
 
 

+ 1 - 1
embedchain/helpers/json_serializable.py

@@ -42,7 +42,7 @@ class JSONSerializable:
     A class to represent a JSON serializable object.
 
     This class provides methods to serialize and deserialize objects,
-    as well as save serialized objects to a file and load them back.
+    as well as to save serialized objects to a file and load them back.
     """
 
     _deserializable_classes = set()  # Contains classes that are whitelisted for deserialization.

+ 8 - 7
embedchain/llm/base.py

@@ -4,9 +4,7 @@ from typing import Any, Dict, Generator, List, Optional
 from langchain.schema import BaseMessage as LCBaseMessage
 
 from embedchain.config import BaseLlmConfig
-from embedchain.config.llm.base import (DEFAULT_PROMPT,
-                                        DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
-                                        DOCS_SITE_PROMPT_TEMPLATE)
+from embedchain.config.llm.base import DEFAULT_PROMPT, DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE, DOCS_SITE_PROMPT_TEMPLATE
 from embedchain.helpers.json_serializable import JSONSerializable
 from embedchain.memory.base import ChatHistory
 from embedchain.memory.message import ChatMessage
@@ -76,7 +74,7 @@ class BaseLlm(JSONSerializable):
         :return: The prompt
         :rtype: str
         """
-        context_string = (" | ").join(contexts)
+        context_string = " | ".join(contexts)
         web_search_result = kwargs.get("web_search_result", "")
         if web_search_result:
             context_string = self._append_search_and_context(context_string, web_search_result)
@@ -110,7 +108,8 @@ class BaseLlm(JSONSerializable):
             prompt = self.config.prompt.substitute(context=context_string, query=input_query)
         return prompt
 
-    def _append_search_and_context(self, context: str, web_search_result: str) -> str:
+    @staticmethod
+    def _append_search_and_context(context: str, web_search_result: str) -> str:
         """Append web search context to existing context
 
         :param context: Existing context
@@ -134,7 +133,8 @@ class BaseLlm(JSONSerializable):
         """
         return self.get_llm_model_answer(prompt)
 
-    def access_search_and_get_results(self, input_query: str):
+    @staticmethod
+    def access_search_and_get_results(input_query: str):
         """
         Search the internet for additional context
 
@@ -153,7 +153,8 @@ class BaseLlm(JSONSerializable):
         logging.info(f"Access search to get answers for {input_query}")
         return search.run(input_query)
 
-    def _stream_response(self, answer: Any) -> Generator[Any, Any, None]:
+    @staticmethod
+    def _stream_response(answer: Any) -> Generator[Any, Any, None]:
         """Generator to be used as streaming response
 
         :param answer: Answer chunk from llm

+ 1 - 1
embedchain/llm/google.py

@@ -44,7 +44,7 @@ class GoogleLlm(BaseLlm):
             "temperature": self.config.temperature or 0.5,
         }
 
-        if self.config.top_p >= 0.0 and self.config.top_p <= 1.0:
+        if 0.0 <= self.config.top_p <= 1.0:
             generation_config_params["top_p"] = self.config.top_p
         else:
             raise ValueError("`top_p` must be > 0.0 and < 1.0")

+ 1 - 1
embedchain/llm/huggingface.py

@@ -48,7 +48,7 @@ class HuggingFaceLlm(BaseLlm):
             "max_new_tokens": config.max_tokens,
         }
 
-        if config.top_p > 0.0 and config.top_p < 1.0:
+        if 0.0 < config.top_p < 1.0:
             model_kwargs["top_p"] = config.top_p
         else:
             raise ValueError("`top_p` must be > 0.0 and < 1.0")

+ 2 - 1
embedchain/llm/ollama.py

@@ -20,7 +20,8 @@ class OllamaLlm(BaseLlm):
     def get_llm_model_answer(self, prompt):
         return self._get_answer(prompt=prompt, config=self.config)
 
-    def _get_answer(self, prompt: str, config: BaseLlmConfig) -> Union[str, Iterable]:
+    @staticmethod
+    def _get_answer(prompt: str, config: BaseLlmConfig) -> Union[str, Iterable]:
         callback_manager = [StreamingStdOutCallbackHandler()] if config.stream else [StdOutCallbackHandler()]
 
         llm = Ollama(

+ 1 - 1
embedchain/loaders/base_loader.py

@@ -5,7 +5,7 @@ class BaseLoader(JSONSerializable):
     def __init__(self):
         pass
 
-    def load_data():
+    def load_data(self, url):
         """
         Implemented by child classes
         """

+ 1 - 1
embedchain/loaders/directory_loader.py

@@ -32,7 +32,7 @@ class DirectoryLoader(BaseLoader):
         doc_id = hashlib.sha256((str(data_list) + str(directory_path)).encode()).hexdigest()
 
         for error in self.errors:
-            logging.warn(error)
+            logging.warning(error)
 
         return {"doc_id": doc_id, "data": data_list}
 

+ 2 - 1
embedchain/loaders/docs_site_loader.py

@@ -49,7 +49,8 @@ class DocsSiteLoader(BaseLoader):
         urls = [link for link in self.visited_links if urlparse(link).netloc == urlparse(url).netloc]
         return urls
 
-    def _load_data_from_url(self, url):
+    @staticmethod
+    def _load_data_from_url(url: str) -> list:
         response = requests.get(url)
         if response.status_code != 200:
             logging.info(f"Failed to fetch the website: {response.status_code}")

+ 19 - 15
embedchain/loaders/github.py

@@ -18,7 +18,7 @@ VALID_SEARCH_TYPES = set(["code", "repo", "pr", "issue", "discussion"])
 
 
 class GithubLoader(BaseLoader):
-    """Load data from github search query."""
+    """Load data from GitHub search query."""
 
     def __init__(self, config: Optional[Dict[str, Any]] = None):
         super().__init__()
@@ -48,7 +48,7 @@ class GithubLoader(BaseLoader):
             self.client = None
 
     def _github_search_code(self, query: str):
-        """Search github code."""
+        """Search GitHub code."""
         data = []
         results = self.client.search_code(query)
         for result in tqdm(results, total=results.totalCount, desc="Loading code files from github"):
@@ -66,7 +66,8 @@ class GithubLoader(BaseLoader):
             )
         return data
 
-    def _get_github_repo_data(self, repo_url: str):
+    @staticmethod
+    def _get_github_repo_data(repo_url: str):
         local_hash = hashlib.sha256(repo_url.encode()).hexdigest()
         local_path = f"/tmp/{local_hash}"
         data = []
@@ -121,14 +122,14 @@ class GithubLoader(BaseLoader):
 
         return data
 
-    def _github_search_repo(self, query: str):
-        """Search github repo."""
+    def _github_search_repo(self, query: str) -> list[dict]:
+        """Search GitHub repo."""
         data = []
         logging.info(f"Searching github repos with query: {query}")
         results = self.client.search_repositories(query)
         # Add repo urls and descriptions
         urls = list(map(lambda x: x.html_url, results))
-        discriptions = list(map(lambda x: x.description, results))
+        descriptions = list(map(lambda x: x.description, results))
         data.append(
             {
                 "content": clean_string(desc),
@@ -136,7 +137,7 @@ class GithubLoader(BaseLoader):
                     "url": url,
                 },
             }
-            for url, desc in zip(urls, discriptions)
+            for url, desc in zip(urls, descriptions)
         )
 
         # Add repo contents
@@ -146,8 +147,8 @@ class GithubLoader(BaseLoader):
             data = self._get_github_repo_data(clone_url)
         return data
 
-    def _github_search_issues_and_pr(self, query: str, type: str):
-        """Search github issues and PRs."""
+    def _github_search_issues_and_pr(self, query: str, type: str) -> list[dict]:
+        """Search GitHub issues and PRs."""
         data = []
 
         query = f"{query} is:{type}"
@@ -161,7 +162,7 @@ class GithubLoader(BaseLoader):
             title = result.title
             body = result.body
             if not body:
-                logging.warn(f"Skipping issue because empty content for: {url}")
+                logging.warning(f"Skipping issue because empty content for: {url}")
                 continue
             labels = " ".join([label.name for label in result.labels])
             issue_comments = result.get_comments()
@@ -186,7 +187,7 @@ class GithubLoader(BaseLoader):
 
     # need to test more for discussion
     def _github_search_discussions(self, query: str):
-        """Search github discussions."""
+        """Search GitHub discussions."""
         data = []
 
         query = f"{query} is:discussion"
@@ -202,7 +203,7 @@ class GithubLoader(BaseLoader):
                     title = discussion.title
                     body = discussion.body
                     if not body:
-                        logging.warn(f"Skipping discussion because empty content for: {url}")
+                        logging.warning(f"Skipping discussion because empty content for: {url}")
                         continue
                     comments = []
                     comments_created_at = []
@@ -233,11 +234,14 @@ class GithubLoader(BaseLoader):
             data = self._github_search_issues_and_pr(query, search_type)
         elif search_type == "discussion":
             raise ValueError("GithubLoader does not support searching discussions yet.")
+        else:
+            raise NotImplementedError(f"{search_type} not supported")
 
         return data
 
-    def _get_valid_github_query(self, query: str):
-        """Check if query is valid and return search types and valid github query."""
+    @staticmethod
+    def _get_valid_github_query(query: str):
+        """Check if query is valid and return search types and valid GitHub query."""
         query_terms = shlex.split(query)
         # query must provide repo to load data from
         if len(query_terms) < 1 or "repo:" not in query:
@@ -273,7 +277,7 @@ class GithubLoader(BaseLoader):
         return types, query
 
     def load_data(self, search_query: str, max_results: int = 1000):
-        """Load data from github search query."""
+        """Load data from GitHub search query."""
 
         if not self.client:
             raise ValueError(

+ 2 - 1
embedchain/loaders/image.py

@@ -20,7 +20,8 @@ class ImageLoader(BaseLoader):
         self.api_key = api_key or os.environ["OPENAI_API_KEY"]
         self.client = OpenAI(api_key=self.api_key)
 
-    def _encode_image(self, image_path: str):
+    @staticmethod
+    def _encode_image(image_path: str):
         with open(image_path, "rb") as image_file:
             return base64.b64encode(image_file.read()).decode("utf-8")
 

+ 2 - 1
embedchain/loaders/json.py

@@ -15,7 +15,8 @@ class JSONReader:
         """Initialize the JSONReader."""
         pass
 
-    def load_data(self, json_data: Union[Dict, str]) -> List[str]:
+    @staticmethod
+    def load_data(json_data: Union[Dict, str]) -> List[str]:
         """Load data from a JSON structure.
 
         Args:

+ 2 - 1
embedchain/loaders/mysql.py

@@ -39,7 +39,8 @@ class MySQLLoader(BaseLoader):
                     Refer `https://docs.embedchain.ai/data-sources/mysql`.",
             )
 
-    def _check_query(self, query):
+    @staticmethod
+    def _check_query(query):
         if not isinstance(query, str):
             raise ValueError(
                 f"Invalid mysql query: {query}",

+ 2 - 2
embedchain/loaders/postgres.py

@@ -24,7 +24,6 @@ class PostgresLoader(BaseLoader):
                     Run `pip install --upgrade 'embedchain[postgres]'`"
             ) from e
 
-        config_info = ""
         if "url" in config:
             config_info = config.get("url")
         else:
@@ -37,7 +36,8 @@ class PostgresLoader(BaseLoader):
         self.connection = psycopg.connect(conninfo=config_info)
         self.cursor = self.connection.cursor()
 
-    def _check_query(self, query):
+    @staticmethod
+    def _check_query(query):
         if not isinstance(query, str):
             raise ValueError(
                 f"Invalid postgres query: {query}. Provide the valid source to add from postgres, make sure you are following `https://docs.embedchain.ai/data-sources/postgres`",  # noqa:E501

+ 2 - 1
embedchain/loaders/slack.py

@@ -56,7 +56,8 @@ class SlackLoader(BaseLoader):
         )
         logging.info("Slack Loader setup successful!")
 
-    def _check_query(self, query):
+    @staticmethod
+    def _check_query(query):
         if not isinstance(query, str):
             raise ValueError(
                 f"Invalid query passed to Slack loader, found: {query}. Check `https://docs.embedchain.ai/data-sources/slack` to learn more."  # noqa:E501

+ 1 - 1
embedchain/loaders/unstructured_file.py

@@ -8,7 +8,7 @@ from embedchain.utils.misc import clean_string
 @register_deserializable
 class UnstructuredLoader(BaseLoader):
     def load_data(self, url):
-        """Load data from a Unstructured file."""
+        """Load data from an Unstructured file."""
         try:
             from langchain.document_loaders import UnstructuredFileLoader
         except ImportError:

+ 5 - 4
embedchain/loaders/web_page.py

@@ -21,7 +21,7 @@ class WebPageLoader(BaseLoader):
     _session = requests.Session()
 
     def load_data(self, url):
-        """Load data from a web page using a shared requests session."""
+        """Load data from a web page using a shared requests' session."""
         response = self._session.get(url, timeout=30)
         response.raise_for_status()
         data = response.content
@@ -40,7 +40,8 @@ class WebPageLoader(BaseLoader):
             ],
         }
 
-    def _get_clean_content(self, html, url) -> str:
+    @staticmethod
+    def _get_clean_content(html, url) -> str:
         soup = BeautifulSoup(html, "html.parser")
         original_size = len(str(soup.get_text()))
 
@@ -60,8 +61,8 @@ class WebPageLoader(BaseLoader):
             tag.decompose()
 
         ids_to_exclude = ["sidebar", "main-navigation", "menu-main-menu"]
-        for id in ids_to_exclude:
-            tags = soup.find_all(id=id)
+        for id_ in ids_to_exclude:
+            tags = soup.find_all(id=id_)
             for tag in tags:
                 tag.decompose()
 

+ 4 - 2
embedchain/memory/base.py

@@ -113,10 +113,12 @@ class ChatHistory:
         count = self.cursor.fetchone()[0]
         return count
 
-    def _serialize_json(self, metadata: Dict[str, Any]):
+    @staticmethod
+    def _serialize_json(metadata: Dict[str, Any]):
         return json.dumps(metadata)
 
-    def _deserialize_json(self, metadata: str):
+    @staticmethod
+    def _deserialize_json(metadata: str):
         return json.loads(metadata)
 
     def close_connection(self):

+ 2 - 2
embedchain/memory/message.py

@@ -54,7 +54,7 @@ class ChatMessage(JSONSerializable):
         if self.human_message:
             logging.info(
                 "Human message already exists in the chat message,\
-                overwritting it with new message."
+                overwriting it with new message."
             )
 
         self.human_message = BaseMessage(content=message, created_by="human", metadata=metadata)
@@ -63,7 +63,7 @@ class ChatMessage(JSONSerializable):
         if self.ai_message:
             logging.info(
                 "AI message already exists in the chat message,\
-                overwritting it with new message."
+                overwriting it with new message."
             )
 
         self.ai_message = BaseMessage(content=message, created_by="ai", metadata=metadata)

+ 1 - 1
embedchain/memory/utils.py

@@ -7,7 +7,7 @@ def merge_metadata_dict(left: Optional[Dict[str, Any]], right: Optional[Dict[str
 
     Args:
         left (Dict[str, Any]): metadata of human message
-        right (Dict[str, Any]): metadata of ai message
+        right (Dict[str, Any]): metadata of AI message
 
     Returns:
         Dict[str, Any]: combined metadata dict with dedup

+ 5 - 3
embedchain/store/assistants.py

@@ -19,7 +19,7 @@ from embedchain.utils.misc import detect_datatype
 
 logging.basicConfig(level=logging.WARN)
 
-# Setup the user directory if doesn't exist already
+# Set up the user directory if it doesn't exist already
 Client.setup_dir()
 
 
@@ -130,12 +130,14 @@ class OpenAIAssistant:
         messages = self._client.beta.threads.messages.list(thread_id=self.thread_id, order="desc")
         return list(messages)
 
-    def _format_message(self, thread_message):
+    @staticmethod
+    def _format_message(thread_message):
         thread_message = cast(ThreadMessage, thread_message)
         content = [c.text.value for c in thread_message.content if isinstance(c, MessageContentText)]
         return " ".join(content)
 
-    def _save_temp_data(self, data, source):
+    @staticmethod
+    def _save_temp_data(data, source):
         special_chars_pattern = r'[\\/:*?"<>|&=% ]+'
         sanitized_source = re.sub(special_chars_pattern, "_", source)[:256]
         temp_dir = tempfile.mkdtemp()

+ 2 - 1
embedchain/telemetry/posthog.py

@@ -38,7 +38,8 @@ class AnonymousTelemetry:
         posthog_logger = logging.getLogger("posthog")
         posthog_logger.disabled = True
 
-    def _get_user_id(self):
+    @staticmethod
+    def _get_user_id():
         os.makedirs(CONFIG_DIR, exist_ok=True)
         if os.path.exists(CONFIG_FILE):
             with open(CONFIG_FILE, "r") as f:

+ 2 - 3
embedchain/utils/misc.py

@@ -201,8 +201,7 @@ def detect_datatype(source: Any) -> DataType:
     formatted_source = format_source(str(source), 30)
 
     if url:
-        from langchain.document_loaders.youtube import \
-            ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
+        from langchain.document_loaders.youtube import ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
 
         if url.netloc in YOUTUBE_ALLOWED_NETLOCS:
             logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")
@@ -345,7 +344,7 @@ def detect_datatype(source: Any) -> DataType:
             return DataType.TEXT_FILE
 
         # If the source is a valid file, that's not detectable as a type, an error is raised.
-        # It does not fallback to text.
+        # It does not fall back to text.
         raise ValueError(
             "Source points to a valid file, but based on the filename, no `data_type` can be detected. Please be aware, that not all data_types allow conventional file references, some require the use of the `file URI scheme`. Please refer to the embedchain documentation (https://docs.embedchain.ai/advanced/data_types#remote-data-types)."  # noqa: E501
         )

+ 1 - 1
embedchain/vectordb/base.py

@@ -49,7 +49,7 @@ class BaseVectorDB(JSONSerializable):
         raise NotImplementedError
 
     def query(self):
-        """Query contents from vector data base based on vector similarity"""
+        """Query contents from vector database based on vector similarity"""
         raise NotImplementedError
 
     def count(self) -> int:

+ 4 - 2
embedchain/vectordb/chroma.py

@@ -75,7 +75,8 @@ class ChromaDB(BaseVectorDB):
         """Called during initialization"""
         return self.client
 
-    def _generate_where_clause(self, where: Dict[str, any]) -> str:
+    @staticmethod
+    def _generate_where_clause(where: Dict[str, any]) -> Dict[str, any]:
         # If only one filter is supplied, return it as is
         # (no need to wrap in $and based on chroma docs)
         if len(where.keys()) <= 1:
@@ -160,7 +161,8 @@ class ChromaDB(BaseVectorDB):
                 ids=ids[i : i + self.BATCH_SIZE],
             )
 
-    def _format_result(self, results: QueryResult) -> list[tuple[Document, float]]:
+    @staticmethod
+    def _format_result(results: QueryResult) -> list[tuple[Document, float]]:
         """
         Format Chroma results
 

+ 2 - 2
embedchain/vectordb/elasticsearch.py

@@ -88,7 +88,7 @@ class ElasticsearchDB(BaseVectorDB):
         """
         Get existing doc ids present in vector database
 
-        :param ids: _list of doc ids to check for existance
+        :param ids: _list of doc ids to check for existence
         :type ids: List[str]
         :param where: to filter data
         :type where: Dict[str, any]
@@ -161,7 +161,7 @@ class ElasticsearchDB(BaseVectorDB):
         **kwargs: Optional[Dict[str, Any]],
     ) -> Union[List[Tuple[str, Dict]], List[str]]:
         """
-        query contents from vector data base based on vector similarity
+        query contents from vector database based on vector similarity
 
         :param input_query: list of query string
         :type input_query: List[str]

+ 1 - 1
embedchain/vectordb/opensearch.py

@@ -163,7 +163,7 @@ class OpenSearchDB(BaseVectorDB):
         **kwargs: Optional[Dict[str, Any]],
     ) -> Union[List[Tuple[str, Dict]], List[str]]:
         """
-        query contents from vector data base based on vector similarity
+        query contents from vector database based on vector similarity
 
         :param input_query: list of query string
         :type input_query: List[str]

+ 2 - 1
embedchain/vectordb/weaviate.py

@@ -305,7 +305,8 @@ class WeaviateDB(BaseVectorDB):
         """
         return f"{self.config.collection_name}_{self.embedder.vector_dimension}".capitalize()
 
-    def _query_with_cursor(self, query, cursor):
+    @staticmethod
+    def _query_with_cursor(query, cursor):
         if cursor is not None:
             query.with_after(cursor)
         results = query.do()

+ 5 - 6
embedchain/vectordb/zilliz.py

@@ -6,8 +6,7 @@ from embedchain.helpers.json_serializable import register_deserializable
 from embedchain.vectordb.base import BaseVectorDB
 
 try:
-    from pymilvus import (Collection, CollectionSchema, DataType, FieldSchema,
-                          MilvusClient, connections, utility)
+    from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, MilvusClient, connections, utility
 except ImportError:
     raise ImportError(
         "Zilliz requires extra dependencies. Install with `pip install --upgrade embedchain[milvus]`"
@@ -97,10 +96,10 @@ class ZillizVectorDB(BaseVectorDB):
         if ids is None or len(ids) == 0 or self.collection.num_entities == 0:
             return {"ids": []}
 
-        if not (self.collection.is_empty):
-            filter = f"id in {ids}"
+        if not self.collection.is_empty:
+            filter_ = f"id in {ids}"
             results = self.client.query(
-                collection_name=self.config.collection_name, filter=filter, output_fields=["id"]
+                collection_name=self.config.collection_name, filter=filter_, output_fields=["id"]
             )
             results = [res["id"] for res in results]
 
@@ -134,7 +133,7 @@ class ZillizVectorDB(BaseVectorDB):
         **kwargs: Optional[Dict[str, Any]],
     ) -> Union[List[Tuple[str, Dict]], List[str]]:
         """
-        Query contents from vector data base based on vector similarity
+        Query contents from vector database based on vector similarity
 
         :param input_query: list of query string
         :type input_query: List[str]

+ 2 - 1
tests/chunkers/test_text.py

@@ -69,7 +69,8 @@ class TestTextChunker:
 
 
 class MockLoader:
-    def load_data(self, src):
+    @staticmethod
+    def load_data(src) -> dict:
         """
         Mock loader that returns a list of data dictionaries.
         Adjust this method to return different data for testing.