
Resolve conflicts (#208)

Deshraj Yadav 2 years ago
parent
commit
9ca836520f

+ 24 - 0
Makefile

@@ -0,0 +1,24 @@
+# Variables
+PYTHON := python3
+PIP := $(PYTHON) -m pip
+PROJECT_NAME := embedchain
+
+# Targets
+.PHONY: install format lint clean test
+
+install:
+	$(PIP) install --upgrade pip
+	$(PIP) install .[dev]
+
+format:
+	$(PYTHON) -m black .
+	$(PYTHON) -m isort .
+
+lint:
+	$(PYTHON) -m ruff .
+
+clean:
+	rm -rf dist build *.egg-info
+
+test:
+	$(PYTHON) -m pytest

+ 22 - 0
README.md

@@ -44,6 +44,7 @@ embedchain is a framework to easily create LLM powered bots over any dataset. If
     - [Reset](#reset)
     - [Count](#count)
 - [How does it work?](#how-does-it-work)
+- [Contribution Guidelines](#contribution-guidelines)
 - [Tech Stack](#tech-stack)
 - [Team](#team)
   - [Author](#author)
@@ -551,6 +552,27 @@ embedchain is a framework which takes care of all these nuances and provides a s

 In the first release, we are making it easier for anyone to get a chatbot over any dataset up and running in less than a minute. All you need to do is create an app instance, add the data sets using `.add` function and then use `.query` function to get the relevant answer.

+# Contribution Guidelines
+
+Thank you for your interest in contributing to the EmbedChain project! We welcome your ideas and contributions to help improve the project. Please follow the instructions below to get started:
+
+1. **Fork the repository**: Click on the "Fork" button at the top right corner of this repository page. This will create a copy of the repository in your own GitHub account.
+
+2. **Install the required dependencies**: Ensure that you have the necessary dependencies installed in your Python environment. You can do this by running the following command:
+
+```bash
+make install
+```
+
+3. **Make changes in the code**: Create a new branch in your forked repository and make your desired changes in the codebase.
+4. **Format code**: Before creating a pull request, it's important to ensure that your code follows our formatting guidelines. Run the following commands to format the code:
+
+```bash
+make lint format
+```
+
+5. **Create a pull request**: When you are ready to contribute your changes, submit a pull request to the EmbedChain repository. Provide a clear and descriptive title for your pull request, along with a detailed description of the changes you have made.
+
 # Tech Stack

 embedchain is built on the following stack:

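For orientation, here is a minimal sketch of the flow the README describes, using the `App` class exported by this package. The URL and query are placeholders, and `OPENAI_API_KEY` must be set in the environment (as the `InitConfig` diff below shows):

```python
from embedchain import App

app = App()  # uses OpenAI embeddings by default; reads OPENAI_API_KEY from the env
app.add("web_page", "https://example.com/docs")  # placeholder URL
print(app.query("What does this page explain?"))
```
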
+ 1 - 1
embedchain/__init__.py

@@ -1 +1 @@
-from .embedchain import App, OpenSourceApp, PersonApp, PersonOpenSourceApp
+from .embedchain import App, OpenSourceApp, PersonApp, PersonOpenSourceApp

+ 6 - 4
embedchain/chunkers/base_chunker.py

@@ -3,15 +3,17 @@ import hashlib

 class BaseChunker:
     def __init__(self, text_splitter):
-        ''' Initialize the chunker. '''
+        """Initialize the chunker."""
         self.text_splitter = text_splitter

     def create_chunks(self, loader, src):
         """
         Loads data and chunks it.

-        :param loader: The loader which's `load_data` method is used to create the raw data.
-        :param src: The data to be handled by the loader. Can be a URL for remote sources or local content for local loaders. 
+        :param loader: The loader which's `load_data` method is used to create
+        the raw data.
+        :param src: The data to be handled by the loader. Can be a URL for
+        remote sources or local content for local loaders.
         """
         """
         documents = []
         documents = []
         ids = []
         ids = []
@@ -27,7 +29,7 @@ class BaseChunker:

             for chunk in chunks:
                 chunk_id = hashlib.sha256((chunk + url).encode()).hexdigest()
-                if (idMap.get(chunk_id) is None):
+                if idMap.get(chunk_id) is None:
                     idMap[chunk_id] = True
                     ids.append(chunk_id)
                     documents.append(chunk)

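The hunk above also shows the chunk-ID scheme: each chunk is deduplicated by hashing the chunk text together with its source URL. A standalone sketch of the same idea, with made-up sample strings:

```python
import hashlib

def make_chunk_id(chunk: str, url: str) -> str:
    # Mirrors create_chunks: identical text from different sources
    # still produces distinct ids.
    return hashlib.sha256((chunk + url).encode()).hexdigest()

id_map, ids = {}, []
for chunk in ["alpha", "alpha", "beta"]:  # the repeated "alpha" is kept once
    chunk_id = make_chunk_id(chunk, "https://example.com")
    if id_map.get(chunk_id) is None:
        id_map[chunk_id] = True
        ids.append(chunk_id)
```
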
+ 4 - 4
embedchain/chunkers/docx_file.py

@@ -1,10 +1,9 @@
 from typing import Optional
-from embedchain.chunkers.base_chunker import BaseChunker
-from embedchain.config.AddConfig import ChunkerConfig

 from langchain.text_splitter import RecursiveCharacterTextSplitter

-
+from embedchain.chunkers.base_chunker import BaseChunker
+from embedchain.config.AddConfig import ChunkerConfig

 TEXT_SPLITTER_CHUNK_PARAMS = {
     "chunk_size": 1000,
@@ -14,7 +13,8 @@ TEXT_SPLITTER_CHUNK_PARAMS = {


 class DocxFileChunker(BaseChunker):
-    ''' Chunker for .docx file. '''
+    """Chunker for .docx file."""
+
     def __init__(self, config: Optional[ChunkerConfig] = None):
         if config is None:
             config = TEXT_SPLITTER_CHUNK_PARAMS

+ 4 - 3
embedchain/chunkers/pdf_file.py

@@ -1,9 +1,9 @@
 from typing import Optional
-from embedchain.chunkers.base_chunker import BaseChunker
-from embedchain.config.AddConfig import ChunkerConfig

 from langchain.text_splitter import RecursiveCharacterTextSplitter

+from embedchain.chunkers.base_chunker import BaseChunker
+from embedchain.config.AddConfig import ChunkerConfig

 TEXT_SPLITTER_CHUNK_PARAMS = {
     "chunk_size": 1000,
@@ -13,7 +13,8 @@ TEXT_SPLITTER_CHUNK_PARAMS = {


 class PdfFileChunker(BaseChunker):
-    ''' Chunker for PDF file. '''
+    """Chunker for PDF file."""
+
     def __init__(self, config: Optional[ChunkerConfig] = None):
         if config is None:
             config = TEXT_SPLITTER_CHUNK_PARAMS

+ 4 - 3
embedchain/chunkers/qna_pair.py

@@ -1,9 +1,9 @@
 from typing import Optional
-from embedchain.chunkers.base_chunker import BaseChunker
-from embedchain.config.AddConfig import ChunkerConfig

 from langchain.text_splitter import RecursiveCharacterTextSplitter

+from embedchain.chunkers.base_chunker import BaseChunker
+from embedchain.config.AddConfig import ChunkerConfig

 TEXT_SPLITTER_CHUNK_PARAMS = {
     "chunk_size": 300,
@@ -13,7 +13,8 @@ TEXT_SPLITTER_CHUNK_PARAMS = {


 class QnaPairChunker(BaseChunker):
-    ''' Chunker for QnA pair. '''
+    """Chunker for QnA pair."""
+
     def __init__(self, config: Optional[ChunkerConfig] = None):
         if config is None:
             config = TEXT_SPLITTER_CHUNK_PARAMS

+ 4 - 3
embedchain/chunkers/text.py

@@ -1,9 +1,9 @@
 from typing import Optional
-from embedchain.chunkers.base_chunker import BaseChunker
-from embedchain.config.AddConfig import ChunkerConfig

 from langchain.text_splitter import RecursiveCharacterTextSplitter

+from embedchain.chunkers.base_chunker import BaseChunker
+from embedchain.config.AddConfig import ChunkerConfig

 TEXT_SPLITTER_CHUNK_PARAMS = {
     "chunk_size": 300,
@@ -13,7 +13,8 @@ TEXT_SPLITTER_CHUNK_PARAMS = {


 class TextChunker(BaseChunker):
-    ''' Chunker for text. '''
+    """Chunker for text."""
+
     def __init__(self, config: Optional[ChunkerConfig] = None):
         if config is None:
             config = TEXT_SPLITTER_CHUNK_PARAMS

+ 4 - 3
embedchain/chunkers/web_page.py

@@ -1,9 +1,9 @@
 from typing import Optional
-from embedchain.chunkers.base_chunker import BaseChunker
-from embedchain.config.AddConfig import ChunkerConfig

 from langchain.text_splitter import RecursiveCharacterTextSplitter

+from embedchain.chunkers.base_chunker import BaseChunker
+from embedchain.config.AddConfig import ChunkerConfig

 TEXT_SPLITTER_CHUNK_PARAMS = {
     "chunk_size": 500,
@@ -13,7 +13,8 @@ TEXT_SPLITTER_CHUNK_PARAMS = {


 class WebPageChunker(BaseChunker):
-    ''' Chunker for web page. '''
+    """Chunker for web page."""
+
     def __init__(self, config: Optional[ChunkerConfig] = None):
         if config is None:
             config = TEXT_SPLITTER_CHUNK_PARAMS

+ 4 - 3
embedchain/chunkers/youtube_video.py

@@ -1,9 +1,9 @@
 from typing import Optional
-from embedchain.chunkers.base_chunker import BaseChunker
-from embedchain.config.AddConfig import ChunkerConfig

 from langchain.text_splitter import RecursiveCharacterTextSplitter

+from embedchain.chunkers.base_chunker import BaseChunker
+from embedchain.config.AddConfig import ChunkerConfig

 TEXT_SPLITTER_CHUNK_PARAMS = {
     "chunk_size": 2000,
@@ -13,7 +13,8 @@ TEXT_SPLITTER_CHUNK_PARAMS = {


 class YoutubeVideoChunker(BaseChunker):
-    ''' Chunker for Youtube video. '''
+    """Chunker for Youtube video."""
+
     def __init__(self, config: Optional[ChunkerConfig] = None):
         if config is None:
             config = TEXT_SPLITTER_CHUNK_PARAMS

+ 17 - 7
embedchain/config/AddConfig.py

@@ -1,4 +1,5 @@
 from typing import Callable, Optional
+
 from embedchain.config.BaseConfig import BaseConfig


@@ -6,27 +7,36 @@ class ChunkerConfig(BaseConfig):
     """
     """
     Config for the chunker used in `add` method
     Config for the chunker used in `add` method
     """
     """
-    def __init__(self,
-                 chunk_size: Optional[int] = 4000,
-                 chunk_overlap: Optional[int] = 200,
-                 length_function: Optional[Callable[[str], int]] = len):
+
+    def __init__(
+        self,
+        chunk_size: Optional[int] = 4000,
+        chunk_overlap: Optional[int] = 200,
+        length_function: Optional[Callable[[str], int]] = len,
+    ):
         self.chunk_size = chunk_size
         self.chunk_overlap = chunk_overlap
         self.length_function = length_function

+
 class LoaderConfig(BaseConfig):
     """
     Config for the chunker used in `add` method
     """
+
     def __init__(self):
         pass

+
 class AddConfig(BaseConfig):
     """
     Config for the `add` method.
     """
-    def __init__(self,
-                 chunker: Optional[ChunkerConfig] = None,
-                 loader: Optional[LoaderConfig] = None):
+
+    def __init__(
+        self,
+        chunker: Optional[ChunkerConfig] = None,
+        loader: Optional[LoaderConfig] = None,
+    ):
         self.loader = loader
         self.chunker = chunker

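A possible way to use the reformatted config classes above when adding data; the values are illustrative, not recommendations:

```python
from embedchain.config import AddConfig
from embedchain.config.AddConfig import ChunkerConfig

# Tighter chunks than the 4000/200 defaults defined above.
chunker_config = ChunkerConfig(chunk_size=500, chunk_overlap=50, length_function=len)
add_config = AddConfig(chunker=chunker_config)
# app.add("web_page", "https://example.com", add_config)
```
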
+ 1 - 0
embedchain/config/BaseConfig.py

@@ -2,6 +2,7 @@ class BaseConfig:
     """
     """
     Base config.
     Base config.
     """
     """
+
     def __init__(self):
         pass


+ 18 - 10
embedchain/config/ChatConfig.py

@@ -1,8 +1,10 @@
-from embedchain.config.QueryConfig import QueryConfig
 from string import Template

+from embedchain.config.QueryConfig import QueryConfig
+
 DEFAULT_PROMPT = """
-  You are a chatbot having a conversation with a human. You are given chat history and context.
+  You are a chatbot having a conversation with a human. You are given chat
+  history and context.
   You need to answer the query considering context, chat history and your knowledge base. If you don't know the answer or the answer is neither contained in the context nor in history, then simply say "I don't know".

   $context
@@ -12,35 +14,41 @@ DEFAULT_PROMPT = """
   Query: $query

   Helpful Answer:
-"""
+"""  # noqa:E501

 DEFAULT_PROMPT_TEMPLATE = Template(DEFAULT_PROMPT)

+
 class ChatConfig(QueryConfig):
     """
     Config for the `chat` method, inherits from `QueryConfig`.
     """
+
     def __init__(self, template: Template = None, stream: bool = False):
         """
         Initializes the ChatConfig instance.

-        :param template: Optional. The `Template` instance to use as a template for prompt.
-        :param stream: Optional. Control if response is streamed back to the user
-        :raises ValueError: If the template is not valid as template should contain $context and $query and $history
+        :param template: Optional. The `Template` instance to use as a
+        template for prompt.
+        :param stream: Optional. Control if response is streamed back to the
+        user
+        :raises ValueError: If the template is not valid as template should
+        contain $context and $query and $history
         """
         """
         if template is None:
         if template is None:
             template = DEFAULT_PROMPT_TEMPLATE
             template = DEFAULT_PROMPT_TEMPLATE
 
 
-        # History is set as 0 to ensure that there is always a history, that way, there don't have to be two templates.
-        # Having two templates would make it complicated because the history is not user controlled.
+        # History is set as 0 to ensure that there is always a history, that
+        # way, there don't have to be two templates.
+        # Having two templates would make it complicated because the history
+        # is not user controlled.
         super().__init__(template, history=[0], stream=stream)

     def set_history(self, history):
         """
         Chat history is not user provided and not set at initialization time
-        
+
         :param history: (string) history to set
         """
         self.history = history
         return
-

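Because `ChatConfig` always passes a history down to `QueryConfig`, a custom chat template has to contain `$context`, `$query`, and `$history` to pass validation. A sketch:

```python
from string import Template

from embedchain.config import ChatConfig

template = Template(
    """Use only this context: $context
    Conversation so far: $history
    Query: $query
    Answer:"""
)
config = ChatConfig(template=template, stream=False)
# app.chat("And what about X?", config)
```
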
+ 11 - 6
embedchain/config/InitConfig.py

@@ -1,8 +1,9 @@
-import os
 import logging
+import os

 from embedchain.config.BaseConfig import BaseConfig

+
 class InitConfig(BaseConfig):
     """
     Config to initialize an embedchain `App` instance.
@@ -10,7 +11,8 @@ class InitConfig(BaseConfig):

     def __init__(self, log_level=None, ef=None, db=None):
         """
-        :param log_level: Optional. (String) Debug level ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
+        :param log_level: Optional. (String) Debug level
+        ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
         :param ef: Optional. Embedding function to use.
         :param db: Optional. (Vector) database to use for embeddings.
         """
@@ -19,16 +21,18 @@ class InitConfig(BaseConfig):
         # Embedding Function
         if ef is None:
             from chromadb.utils import embedding_functions
+
             self.ef = embedding_functions.OpenAIEmbeddingFunction(
                 api_key=os.getenv("OPENAI_API_KEY"),
                 organization_id=os.getenv("OPENAI_ORGANIZATION"),
-                model_name="text-embedding-ada-002"
+                model_name="text-embedding-ada-002",
             )
         else:
             self.ef = ef

         if db is None:
             from embedchain.vectordb.chroma_db import ChromaDB
+
             self.db = ChromaDB(ef=self.ef)
         else:
             self.db = db
@@ -44,9 +48,10 @@ class InitConfig(BaseConfig):
         if debug_level is not None:
             level = getattr(logging, debug_level.upper(), None)
             if not isinstance(level, int):
-                raise ValueError(f'Invalid log level: {debug_level}')
+                raise ValueError(f"Invalid log level: {debug_level}")
-        logging.basicConfig(format="%(asctime)s [%(name)s] [%(levelname)s] %(message)s",
-                            level=level)
+        logging.basicConfig(
+            format="%(asctime)s [%(name)s] [%(levelname)s] %(message)s", level=level
+        )
         self.logger = logging.getLogger(__name__)
         return

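For illustration, here is how the validated `log_level` might be used when constructing an app. This is a sketch: it assumes `App` accepts an `InitConfig` and that the OpenAI environment variables mentioned above are set:

```python
from embedchain import App
from embedchain.config import InitConfig

config = InitConfig(log_level="DEBUG")  # an unknown level name raises ValueError
app = App(config)
```
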
+ 22 - 14
embedchain/config/QueryConfig.py

@@ -1,6 +1,7 @@
-from embedchain.config.BaseConfig import BaseConfig
-from string import Template
 import re
+from string import Template
+
+from embedchain.config.BaseConfig import BaseConfig

 DEFAULT_PROMPT = """
   Use the following pieces of context to answer the query at the end.
@@ -11,7 +12,7 @@ DEFAULT_PROMPT = """
   Query: $query

   Helpful Answer:
-"""
+"""  # noqa:E501

 DEFAULT_PROMPT_WITH_HISTORY = """
   Use the following pieces of context to answer the query at the end.
@@ -25,7 +26,7 @@ DEFAULT_PROMPT_WITH_HISTORY = """
   Query: $query

   Helpful Answer:
-"""
+"""  # noqa:E501

 DEFAULT_PROMPT_TEMPLATE = Template(DEFAULT_PROMPT)
 DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE = Template(DEFAULT_PROMPT_WITH_HISTORY)
@@ -38,14 +39,17 @@ class QueryConfig(BaseConfig):
     """
     """
     Config for the `query` method.
     Config for the `query` method.
     """
     """
-    def __init__(self, template: Template = None, history = None, stream: bool = False):
+
+    def __init__(self, template: Template = None, history=None, stream: bool = False):
         """
         """
         Initializes the QueryConfig instance.
         Initializes the QueryConfig instance.
 
 
-        :param template: Optional. The `Template` instance to use as a template for prompt.
+        :param template: Optional. The `Template` instance to use as a
+        template for prompt.
         :param history: Optional. A list of strings to consider as history.
-        :param stream: Optional. Control if response is streamed back to the user
-        :raises ValueError: If the template is not valid as template should contain $context and $query (and optionally $history).
+        :param stream: Optional. Control if response is streamed back to user
+        :raises ValueError: If the template is not valid as template should
+        contain $context and $query (and optionally $history).
         """
         """
         if not history:
         if not history:
             self.history = None
             self.history = None
@@ -67,12 +71,13 @@ class QueryConfig(BaseConfig):
             if self.history is None:
                 raise ValueError("`template` should have `query` and `context` keys")
             else:
-                raise ValueError("`template` should have `query`, `context` and `history` keys")
+                raise ValueError(
+                    "`template` should have `query`, `context` and `history` keys"
+                )

         if not isinstance(stream, bool):
             raise ValueError("`stream` should be bool")
         self.stream = stream
-                

     def validate_template(self, template: Template):
         """
@@ -82,9 +87,12 @@ class QueryConfig(BaseConfig):
         :return: Boolean, valid (true) or invalid (false)
         """
         if self.history is None:
-            return (re.search(query_re, template.template) \
-                and re.search(context_re, template.template))
+            return re.search(query_re, template.template) and re.search(
+                context_re, template.template
+            )
         else:
-            return (re.search(query_re, template.template) \
+            return (
+                re.search(query_re, template.template)
                 and re.search(context_re, template.template)
-                and re.search(history_re, template.template))
+                and re.search(history_re, template.template)
+            )

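A sketch of a custom query template that satisfies `validate_template` above: it must contain `$context` and `$query`, and `$history` as well once a history is set:

```python
from string import Template

from embedchain.config import QueryConfig

template = Template("Answer only from this context: $context\nQuery: $query\nAnswer:")
config = QueryConfig(template=template)  # raises ValueError if a key is missing
# app.query("Some question", config)
```
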
+ 2 - 2
embedchain/config/__init__.py

@@ -1,5 +1,5 @@
-from .BaseConfig import BaseConfig
 from .AddConfig import AddConfig
+from .BaseConfig import BaseConfig
 from .ChatConfig import ChatConfig
 from .InitConfig import InitConfig
-from .QueryConfig import QueryConfig
+from .QueryConfig import QueryConfig

+ 1 - 1
embedchain/data_formatter/__init__.py

@@ -1 +1 @@
-from .data_formatter import DataFormatter
+from .data_formatter import DataFormatter

+ 24 - 23
embedchain/data_formatter/data_formatter.py

@@ -1,24 +1,25 @@
-from embedchain.config import AddConfig
-from embedchain.loaders.youtube_video import YoutubeVideoLoader
-from embedchain.loaders.pdf_file import PdfFileLoader
-from embedchain.loaders.web_page import WebPageLoader
-from embedchain.loaders.local_qna_pair import LocalQnaPairLoader
-from embedchain.loaders.local_text import LocalTextLoader
-from embedchain.loaders.docx_file import DocxFileLoader
-from embedchain.chunkers.youtube_video import YoutubeVideoChunker
+from embedchain.chunkers.docx_file import DocxFileChunker
 from embedchain.chunkers.pdf_file import PdfFileChunker
-from embedchain.chunkers.web_page import WebPageChunker
 from embedchain.chunkers.qna_pair import QnaPairChunker
 from embedchain.chunkers.text import TextChunker
-from embedchain.chunkers.docx_file import DocxFileChunker
+from embedchain.chunkers.web_page import WebPageChunker
+from embedchain.chunkers.youtube_video import YoutubeVideoChunker
+from embedchain.config import AddConfig
+from embedchain.loaders.docx_file import DocxFileLoader
+from embedchain.loaders.local_qna_pair import LocalQnaPairLoader
+from embedchain.loaders.local_text import LocalTextLoader
+from embedchain.loaders.pdf_file import PdfFileLoader
+from embedchain.loaders.web_page import WebPageLoader
+from embedchain.loaders.youtube_video import YoutubeVideoLoader


 class DataFormatter:
     """
     DataFormatter is an internal utility class which abstracts the mapping for
     loaders and chunkers to the data_type entered by the user in their
-    .add or .add_local method call 
+    .add or .add_local method call
     """
     """
+
     def __init__(self, data_type: str, config: AddConfig):
         self.loader = self._get_loader(data_type, config.loader)
         self.chunker = self._get_chunker(data_type, config.chunker)
@@ -32,12 +33,12 @@ class DataFormatter:
         :raises ValueError: If an unsupported data type is provided.
         """
         loaders = {
-            'youtube_video': YoutubeVideoLoader(),
-            'pdf_file': PdfFileLoader(),
-            'web_page': WebPageLoader(),
-            'qna_pair': LocalQnaPairLoader(),
-            'text': LocalTextLoader(),
-            'docx': DocxFileLoader(),
+            "youtube_video": YoutubeVideoLoader(),
+            "pdf_file": PdfFileLoader(),
+            "web_page": WebPageLoader(),
+            "qna_pair": LocalQnaPairLoader(),
+            "text": LocalTextLoader(),
+            "docx": DocxFileLoader(),
         }
         if data_type in loaders:
             return loaders[data_type]
@@ -53,12 +54,12 @@ class DataFormatter:
         :raises ValueError: If an unsupported data type is provided.
         """
         chunkers = {
-            'youtube_video': YoutubeVideoChunker(config),
-            'pdf_file': PdfFileChunker(config),
-            'web_page': WebPageChunker(config),
-            'qna_pair': QnaPairChunker(config),
-            'text': TextChunker(config),
-            'docx': DocxFileChunker(config),
+            "youtube_video": YoutubeVideoChunker(config),
+            "pdf_file": PdfFileChunker(config),
+            "web_page": WebPageChunker(config),
+            "qna_pair": QnaPairChunker(config),
+            "text": TextChunker(config),
+            "docx": DocxFileChunker(config),
         }
         if data_type in chunkers:
             return chunkers[data_type]

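The two lookup tables above resolve one loader and one chunker per `data_type`. A quick illustration of the mapping (internal API, shown only for clarity):

```python
from embedchain.config import AddConfig
from embedchain.data_formatter import DataFormatter

formatter = DataFormatter("web_page", AddConfig())
# formatter.loader is a WebPageLoader, formatter.chunker a WebPageChunker;
# an unsupported data_type such as "csv" raises ValueError.
```
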
+ 70 - 58
embedchain/embedchain.py

@@ -1,17 +1,16 @@
-import openai
-import os
 import logging
+import os
 from string import Template

+import openai
 from chromadb.utils import embedding_functions
 from dotenv import load_dotenv
 from langchain.docstore.document import Document
-from langchain.embeddings.openai import OpenAIEmbeddings
 from langchain.memory import ConversationBufferMemory
-from embedchain.config import InitConfig, AddConfig, QueryConfig, ChatConfig
+
+from embedchain.config import AddConfig, ChatConfig, InitConfig, QueryConfig
 from embedchain.config.QueryConfig import DEFAULT_PROMPT
 from embedchain.data_formatter import DataFormatter
-from string import Template

 gpt4all_model = None

@@ -45,7 +44,8 @@ class EmbedChain:

         :param data_type: The type of the data to add.
         :param url: The URL where the data is located.
-        :param config: Optional. The `AddConfig` instance to use as configuration options.
+        :param config: Optional. The `AddConfig` instance to use as configuration
+        options.
         """
         """
         if config is None:
         if config is None:
             config = AddConfig()
             config = AddConfig()
@@ -62,22 +62,28 @@ class EmbedChain:

         :param data_type: The type of the data to add.
         :param content: The local data. Refer to the `README` for formatting.
-        :param config: Optional. The `AddConfig` instance to use as configuration options.
+        :param config: Optional. The `AddConfig` instance to use as
+        configuration options.
         """
         """
         if config is None:
         if config is None:
             config = AddConfig()
             config = AddConfig()
 
 
         data_formatter = DataFormatter(data_type, config)
         data_formatter = DataFormatter(data_type, config)
         self.user_asks.append([data_type, content])
         self.user_asks.append([data_type, content])
-        self.load_and_embed(data_formatter.loader, data_formatter.chunker, content)
+        self.load_and_embed(
+            data_formatter.loader,
+            data_formatter.chunker,
+            content,
+        )

     def load_and_embed(self, loader, chunker, src):
         """
-        Loads the data from the given URL, chunks it, and adds it to the database.
+        Loads the data from the given URL, chunks it, and adds it to database.

         :param loader: The loader to use to load the data.
         :param chunker: The chunker to use to chunk the data.
-        :param src: The data to be handled by the loader. Can be a URL for remote sources or local content for local loaders.
+        :param src: The data to be handled by the loader. Can be a URL for
+        remote sources or local content for local loaders.
         """
         """
         embeddings_data = chunker.create_chunks(loader, src)
         embeddings_data = chunker.create_chunks(loader, src)
         documents = embeddings_data["documents"]
         documents = embeddings_data["documents"]
@@ -91,8 +97,12 @@ class EmbedChain:
         existing_ids = set(existing_docs["ids"])

         if len(existing_ids):
-            data_dict = {id: (doc, meta) for id, doc, meta in zip(ids, documents, metadatas)}
-            data_dict = {id: value for id, value in data_dict.items() if id not in existing_ids}
+            data_dict = {
+                id: (doc, meta) for id, doc, meta in zip(ids, documents, metadatas)
+            }
+            data_dict = {
+                id: value for id, value in data_dict.items() if id not in existing_ids
+            }

             if not data_dict:
                 print(f"All data from {src} already exists in the database.")
@@ -103,12 +113,10 @@ class EmbedChain:

         chunks_before_addition = self.count()

-        self.collection.add(
-            documents=documents,
-            metadatas=list(metadatas),
-            ids=ids
+        self.collection.add(documents=documents, metadatas=list(metadatas), ids=ids)
+        print(
+            f"Successfully saved {src}. New chunks count: {self.count() - chunks_before_addition}"  # noqa:E501
         )
-        print(f"Successfully saved {src}. New chunks count: {self.count() - chunks_before_addition}")

     def _format_result(self, results):
         return [
@@ -132,7 +140,9 @@ class EmbedChain:
         :return: The content of the document that matched your query.
         """
         result = self.collection.query(
-            query_texts=[input_query,],
+            query_texts=[
+                input_query,
+            ],
             n_results=1,
         )
         result_formatted = self._format_result(result)
@@ -144,17 +154,21 @@ class EmbedChain:

     def generate_prompt(self, input_query, context, config: QueryConfig):
         """
-        Generates a prompt based on the given query and context, ready to be passed to an LLM
+        Generates a prompt based on the given query and context, ready to be
+        passed to an LLM

         :param input_query: The query to use.
         :param context: Similar documents to the query used as context.
-        :param config: Optional. The `QueryConfig` instance to use as configuration options.
+        :param config: Optional. The `QueryConfig` instance to use as
+        configuration options.
         :return: The prompt
         """
         if not config.history:
-            prompt = config.template.substitute(context = context, query = input_query)
+            prompt = config.template.substitute(context=context, query=input_query)
         else:
-            prompt = config.template.substitute(context = context, query = input_query, history = config.history)
+            prompt = config.template.substitute(
+                context=context, query=input_query, history=config.history
+            )
         return prompt

     def get_answer_from_llm(self, prompt, config: ChatConfig):
@@ -166,7 +180,7 @@ class EmbedChain:
         :param context: Similar documents to the query used as context.
         :return: The answer.
         """
-        
+
         return self.get_llm_model_answer(prompt, config)

     def query(self, input_query, config: QueryConfig = None):
@@ -176,7 +190,8 @@ class EmbedChain:
         LLM as context to get the answer.

         :param input_query: The query to use.
-        :param config: Optional. The `QueryConfig` instance to use as configuration options.
+        :param config: Optional. The `QueryConfig` instance to use as
+        configuration options.
         :return: The answer to the query.
         """
         if config is None:
@@ -188,7 +203,6 @@ class EmbedChain:
         logging.info(f"Answer: {answer}")
         return answer

-
     def chat(self, input_query, config: ChatConfig = None):
         """
         Queries the vector database on the given input query.
@@ -197,30 +211,31 @@ class EmbedChain:

         Maintains last 5 conversations in memory.
         :param input_query: The query to use.
-        :param config: Optional. The `ChatConfig` instance to use as configuration options.
+        :param config: Optional. The `ChatConfig` instance to use as
+        configuration options.
         :return: The answer to the query.
         """
         context = self.retrieve_from_database(input_query)
         global memory
         chat_history = memory.load_memory_variables({})["history"]
-        
+
         if config is None:
             config = ChatConfig()
         if chat_history:
             config.set_history(chat_history)
-            
+
         prompt = self.generate_prompt(input_query, context, config)
         logging.info(f"Prompt: {prompt}")
         answer = self.get_answer_from_llm(prompt, config)

         memory.chat_memory.add_user_message(input_query)
-        
+
         if isinstance(answer, str):
             memory.chat_memory.add_ai_message(answer)
             logging.info(f"Answer: {answer}")
             return answer
         else:
-            #this is a streamed response and needs to be handled differently.
+            # this is a streamed response and needs to be handled differently.
             return self._stream_chat_response(answer)

     def _stream_chat_response(self, answer):
@@ -230,7 +245,6 @@ class EmbedChain:
             yield chunk
         memory.chat_memory.add_ai_message(streamed_answer)
         logging.info(f"Answer: {streamed_answer}")
-          

     def dry_run(self, input_query, config: QueryConfig = None):
         """
@@ -242,7 +256,8 @@ class EmbedChain:
         the `max_tokens` parameter.

         :param input_query: The query to use.
-        :param config: Optional. The `QueryConfig` instance to use as configuration options.
+        :param config: Optional. The `QueryConfig` instance to use as
+        configuration options.
         :return: The prompt that would be sent to the LLM
         """
         if config is None:
@@ -260,7 +275,6 @@ class EmbedChain:
         """
         """
         return self.collection.count()
         return self.collection.count()
 
 
-
     def reset(self):
         """
         Resets the database. Deletes all embeddings irreversibly.
@@ -288,35 +302,31 @@ class App(EmbedChain):
         super().__init__(config)

     def get_llm_model_answer(self, prompt, config: ChatConfig):
-
         messages = []
-        messages.append({
-            "role": "user", "content": prompt
-        })
+        messages.append({"role": "user", "content": prompt})
         response = openai.ChatCompletion.create(
             model="gpt-3.5-turbo-0613",
             messages=messages,
             temperature=0,
             max_tokens=1000,
             top_p=1,
-            stream=config.stream
+            stream=config.stream,
         )

         if config.stream:
             return self._stream_llm_model_response(response)
         else:
             return response["choices"][0]["message"]["content"]
-    
+
     def _stream_llm_model_response(self, response):
         """
         This is a generator for streaming response from the OpenAI completions API
         """
         for line in response:
-            chunk = line['choices'][0].get('delta', {}).get('content', '')
+            chunk = line["choices"][0].get("delta", {}).get("content", "")
             yield chunk


-
 class OpenSourceApp(EmbedChain):
     """
     The OpenSource app.
@@ -330,20 +340,24 @@ class OpenSourceApp(EmbedChain):

     def __init__(self, config: InitConfig = None):
         """
-        :param config: InitConfig instance to load as configuration. Optional. `ef` defaults to open source.
+        :param config: InitConfig instance to load as configuration. Optional.
+        `ef` defaults to open source.
         """
         """
-        print("Loading open source embedding model. This may take some time...")
+        print(
+            "Loading open source embedding model. This may take some time..."
+        )  # noqa:E501
         if not config:
             config = InitConfig(
-                ef = embedding_functions.SentenceTransformerEmbeddingFunction(
+                ef=embedding_functions.SentenceTransformerEmbeddingFunction(
                     model_name="all-MiniLM-L6-v2"
                 )
             )
         elif not config.ef:
             config._set_embedding_function(
-                    embedding_functions.SentenceTransformerEmbeddingFunction(
-                model_name="all-MiniLM-L6-v2"
-            ))
+                embedding_functions.SentenceTransformerEmbeddingFunction(
+                    model_name="all-MiniLM-L6-v2"
+                )
+            )
         print("Successfully loaded open source embedding model.")
         print("Successfully loaded open source embedding model.")
         super().__init__(config)
         super().__init__(config)
 
 
@@ -353,10 +367,7 @@ class OpenSourceApp(EmbedChain):
         global gpt4all_model
         if gpt4all_model is None:
             gpt4all_model = GPT4All("orca-mini-3b.ggmlv3.q4_0.bin")
-        response = gpt4all_model.generate(
-            prompt=prompt,
-            streaming=config.stream
-        )
+        response = gpt4all_model.generate(prompt=prompt, streaming=config.stream)
         return response


@@ -368,12 +379,11 @@ class EmbedChainPersonApp:
     :param person: name of the person, better if its a well known person.
     :param config: InitConfig instance to load as configuration.
     """
+
     def __init__(self, person, config: InitConfig = None):
         self.person = person
-        self.person_prompt = f"You are {person}. Whatever you say, you will always say in {person} style."
-        self.template = Template(
-            self.person_prompt + " " + DEFAULT_PROMPT
-        )
+        self.person_prompt = f"You are {person}. Whatever you say, you will always say in {person} style."  # noqa:E501
+        self.template = Template(self.person_prompt + " " + DEFAULT_PROMPT)
         if config is None:
             config = InitConfig()
         super().__init__(config)
@@ -384,6 +394,7 @@ class PersonApp(EmbedChainPersonApp, App):
     The Person app.
     Extends functionality from EmbedChainPersonApp and App
     """
+
     def query(self, input_query, config: QueryConfig = None):
         query_config = QueryConfig(
             template=self.template,
@@ -392,7 +403,7 @@ class PersonApp(EmbedChainPersonApp, App):

     def chat(self, input_query, config: ChatConfig = None):
         chat_config = ChatConfig(
-            template = self.template,
+            template=self.template,
         )
         return super().chat(input_query, chat_config)

@@ -402,6 +413,7 @@ class PersonOpenSourceApp(EmbedChainPersonApp, OpenSourceApp):
     The Person app.
     Extends functionality from EmbedChainPersonApp and OpenSourceApp
     """
+
     def query(self, input_query, config: QueryConfig = None):
         query_config = QueryConfig(
             template=self.template,
@@ -410,6 +422,6 @@ class PersonOpenSourceApp(EmbedChainPersonApp, OpenSourceApp):

     def chat(self, input_query, config: ChatConfig = None):
         chat_config = ChatConfig(
-            template = self.template,
+            template=self.template,
         )
-        return super().chat(input_query, chat_config)
+        return super().chat(input_query, chat_config)

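One behavior worth noting in the diff above: with `stream=True`, `chat` returns a generator rather than a string. A hedged usage sketch:

```python
from embedchain import App
from embedchain.config import ChatConfig

app = App()
# chat() yields chunks as they arrive instead of returning one string
for chunk in app.chat("Summarize what you know.", ChatConfig(stream=True)):
    print(chunk, end="", flush=True)
```
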
+ 2 - 1
embedchain/loaders/docx_file.py

@@ -1,8 +1,9 @@
 from langchain.document_loaders import Docx2txtLoader

+
 class DocxFileLoader:
     def load_data(self, url):
-        ''' Load data from a .docx file. '''
+        """Load data from a .docx file."""
         loader = Docx2txtLoader(url)
         output = []
         data = loader.load()

+ 7 - 6
embedchain/loaders/local_qna_pair.py

@@ -1,13 +1,14 @@
 class LocalQnaPairLoader:
-
     def load_data(self, content):
-        ''' Load data from a local QnA pair. '''
+        """Load data from a local QnA pair."""
         question, answer = content
         content = f"Q: {question}\nA: {answer}"
         meta_data = {
             "url": "local",
         }
-        return [{
-            "content": content,
-            "meta_data": meta_data,
-        }]
+        return [
+            {
+                "content": content,
+                "meta_data": meta_data,
+            }
+        ]

+ 7 - 6
embedchain/loaders/local_text.py

@@ -1,11 +1,12 @@
 class LocalTextLoader:
-
     def load_data(self, content):
-        ''' Load data from a local text file. '''
+        """Load data from a local text file."""
         meta_data = {
             "url": "local",
         }
-        return [{
-            "content": content,
-            "meta_data": meta_data,
-        }]
+        return [
+            {
+                "content": content,
+                "meta_data": meta_data,
+            }
+        ]

+ 7 - 6
embedchain/loaders/pdf_file.py

@@ -4,9 +4,8 @@ from embedchain.utils import clean_string


 class PdfFileLoader:
-    
     def load_data(self, url):
-        ''' Load data from a PDF file. '''
+        """Load data from a PDF file."""
         loader = PyPDFLoader(url)
         output = []
         pages = loader.load_and_split()
@@ -17,8 +16,10 @@ class PdfFileLoader:
             content = clean_string(content)
             meta_data = page.metadata
             meta_data["url"] = url
-            output.append({
-                "content": content,
-                "meta_data": meta_data,
-            })
+            output.append(
+                {
+                    "content": content,
+                    "meta_data": meta_data,
+                }
+            )
         return output

+ 23 - 14
embedchain/loaders/web_page.py

@@ -1,22 +1,29 @@
 import requests
-
 from bs4 import BeautifulSoup

 from embedchain.utils import clean_string


 class WebPageLoader:
-
     def load_data(self, url):
-        ''' Load data from a web page. '''
+        """Load data from a web page."""
         response = requests.get(url)
         data = response.content
-        soup = BeautifulSoup(data, 'html.parser')
-        for tag in soup([
-            "nav", "aside", "form", "header",
-            "noscript", "svg", "canvas",
-            "footer", "script", "style"
-        ]):
+        soup = BeautifulSoup(data, "html.parser")
+        for tag in soup(
+            [
+                "nav",
+                "aside",
+                "form",
+                "header",
+                "noscript",
+                "svg",
+                "canvas",
+                "footer",
+                "script",
+                "style",
+            ]
+        ):
             tag.string = " "
         output = []
         content = soup.get_text()
@@ -24,8 +31,10 @@ class WebPageLoader:
         meta_data = {
             "url": url,
         }
-        output.append({
-            "content": content,
-            "meta_data": meta_data,
-        })
-        return output
+        output.append(
+            {
+                "content": content,
+                "meta_data": meta_data,
+            }
+        )
+        return output

+ 7 - 6
embedchain/loaders/youtube_video.py

@@ -4,9 +4,8 @@ from embedchain.utils import clean_string


 class YoutubeVideoLoader:
-
     def load_data(self, url):
-        ''' Load data from a Youtube video. '''
+        """Load data from a Youtube video."""
         loader = YoutubeLoader.from_youtube_url(url, add_video_info=True)
         doc = loader.load()
         output = []
@@ -16,8 +15,10 @@ class YoutubeVideoLoader:
         content = clean_string(content)
         meta_data = doc[0].metadata
         meta_data["url"] = url
-        output.append({
-            "content": content,
-            "meta_data": meta_data,
-        })
+        output.append(
+            {
+                "content": content,
+                "meta_data": meta_data,
+            }
+        )
         return output

+ 12 - 9
embedchain/utils.py

@@ -3,30 +3,33 @@ import re

 def clean_string(text):
     """
-    This function takes in a string and performs a series of text cleaning operations. 
+    This function takes in a string and performs a series of text cleaning operations.
 
 
     Args:
     Args:
         text (str): The text to be cleaned. This is expected to be a string.
         text (str): The text to be cleaned. This is expected to be a string.
 
 
     Returns:
     Returns:
-        cleaned_text (str): The cleaned text after all the cleaning operations have been performed.
+        cleaned_text (str): The cleaned text after all the cleaning operations
+        have been performed.
     """
     """
     # Replacement of newline characters:
     # Replacement of newline characters:
-    text = text.replace('\n', ' ')
+    text = text.replace("\n", " ")
 
 
     # Stripping and reducing multiple spaces to single:
     # Stripping and reducing multiple spaces to single:
-    cleaned_text = re.sub(r'\s+', ' ', text.strip())
+    cleaned_text = re.sub(r"\s+", " ", text.strip())
 
 
     # Removing backslashes:
     # Removing backslashes:
-    cleaned_text = cleaned_text.replace('\\', '')
+    cleaned_text = cleaned_text.replace("\\", "")
 
 
     # Replacing hash characters:
     # Replacing hash characters:
-    cleaned_text = cleaned_text.replace('#', ' ')
+    cleaned_text = cleaned_text.replace("#", " ")
 
 
     # Eliminating consecutive non-alphanumeric characters:
     # Eliminating consecutive non-alphanumeric characters:
-    # This regex identifies consecutive non-alphanumeric characters (i.e., not a word character [a-zA-Z0-9_] and not a whitespace) in the string 
-    # and replaces each group of such characters with a single occurrence of that character. 
+    # This regex identifies consecutive non-alphanumeric characters (i.e., not
+    # a word character [a-zA-Z0-9_] and not a whitespace) in the string
+    # and replaces each group of such characters with a single occurrence of
+    # that character.
     # For example, "!!! hello !!!" would become "! hello !".
     # For example, "!!! hello !!!" would become "! hello !".
-    cleaned_text = re.sub(r'([^\w\s])\1*', r'\1', cleaned_text)
+    cleaned_text = re.sub(r"([^\w\s])\1*", r"\1", cleaned_text)
 
 
     return cleaned_text
     return cleaned_text
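
The quoting change is cosmetic; behavior is identical. A quick check against the example in the docstring:

```python
from embedchain.utils import clean_string

# Repeated punctuation is squeezed and whitespace collapsed,
# per the docstring's own example.
print(clean_string("!!! hello !!!\n\nworld"))  # -> "! hello ! world"
```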

+ 2 - 2
embedchain/vectordb/base_vector_db.py

@@ -1,12 +1,12 @@
 class BaseVectorDB:
-    ''' Base class for vector database. '''
+    """Base class for vector database."""
 
     def __init__(self):
         self.client = self._get_or_create_db()
         self.collection = self._get_or_create_collection()
 
     def _get_or_create_db(self):
-        ''' Get or create the database. '''
+        """Get or create the database."""
         raise NotImplementedError
 
     def _get_or_create_collection(self):
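
`__init__` here is a small template method: subclasses provide the two `_get_or_create_*` hooks and inherit the wiring. A hypothetical in-memory subclass, purely to illustrate the contract (not part of this commit):

```python
from embedchain.vectordb.base_vector_db import BaseVectorDB


class InMemoryDB(BaseVectorDB):
    """Hypothetical subclass showing the two required hooks."""

    def _get_or_create_db(self):
        return {}  # stands in for a real database client

    def _get_or_create_collection(self):
        return []  # stands in for a real collection


db = InMemoryDB()  # __init__ wires self.client and self.collection via the hooks
```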

+ 9 - 8
embedchain/vectordb/chroma_db.py

@@ -1,14 +1,14 @@
-import chromadb
 import os
 
+import chromadb
 from chromadb.utils import embedding_functions
 
 from embedchain.vectordb.base_vector_db import BaseVectorDB
 
 
 class ChromaDB(BaseVectorDB):
-    ''' Vector database using ChromaDB. '''
-    
+    """Vector database using ChromaDB."""
+
     def __init__(self, db_dir=None, ef=None):
         if ef:
             self.ef = ef
@@ -16,23 +16,24 @@ class ChromaDB(BaseVectorDB):
             self.ef = embedding_functions.OpenAIEmbeddingFunction(
                 api_key=os.getenv("OPENAI_API_KEY"),
                 organization_id=os.getenv("OPENAI_ORGANIZATION"),
-                model_name="text-embedding-ada-002"
+                model_name="text-embedding-ada-002",
             )
         if db_dir is None:
             db_dir = "db"
         self.client_settings = chromadb.config.Settings(
             chroma_db_impl="duckdb+parquet",
             persist_directory=db_dir,
-            anonymized_telemetry=False
+            anonymized_telemetry=False,
         )
         super().__init__()
 
     def _get_or_create_db(self):
-        ''' Get or create the database. '''
+        """Get or create the database."""
         return chromadb.Client(self.client_settings)
 
     def _get_or_create_collection(self):
-        ''' Get or create the collection. '''
+        """Get or create the collection."""
         return self.client.get_or_create_collection(
-            'embedchain_store', embedding_function=self.ef,
+            "embedchain_store",
+            embedding_function=self.ef,
         )
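
A minimal usage sketch, assuming `OPENAI_API_KEY` is set in the environment (the default embedding function above calls OpenAI's `text-embedding-ada-002`):

```python
from embedchain.vectordb.chroma_db import ChromaDB

db = ChromaDB(db_dir="db")   # configured above to persist to ./db (duckdb+parquet)
collection = db.collection   # the "embedchain_store" chromadb collection
print(collection.count())    # 0 on a fresh store
```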

+ 68 - 0
pyproject.toml

@@ -0,0 +1,68 @@
+[build-system]
+requires = ["setuptools", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[tool.ruff]
+select = ["E", "F"]
+ignore = []
+fixable = ["A", "B", "C", "D", "E", "F", "G", "I", "N", "Q", "S", "T", "W", "ANN", "ARG", "BLE", "COM", "DJ", "DTZ", "EM", "ERA", "EXE", "FBT", "ICN", "INP", "ISC", "NPY", "PD", "PGH", "PIE", "PL", "PT", "PTH", "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", "TID", "TRY", "UP", "YTT"]
+unfixable = []
+exclude = [
+    ".bzr",
+    ".direnv",
+    ".eggs",
+    ".git",
+    ".git-rewrite",
+    ".hg",
+    ".mypy_cache",
+    ".nox",
+    ".pants.d",
+    ".pytype",
+    ".ruff_cache",
+    ".svn",
+    ".tox",
+    ".venv",
+    "__pypackages__",
+    "_build",
+    "buck-out",
+    "build",
+    "dist",
+    "node_modules",
+    "venv",
+]
+line-length = 88
+dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
+target-version = "py38"
+
+[tool.ruff.mccabe]
+max-complexity = 10
+
+[tool.black]
+line-length = 88
+target-version = ["py38", "py39", "py310", "py311"]
+include = '\.pyi?$'
+exclude = '''
+/(
+    \.eggs
+  | \.git
+  | \.hg
+  | \.mypy_cache
+  | \.nox
+  | \.pants.d
+  | \.pytype
+  | \.ruff_cache
+  | \.svn
+  | \.tox
+  | \.venv
+  | __pypackages__
+  | _build
+  | buck-out
+  | build
+  | dist
+  | node_modules
+  | venv
+)/
+'''
+
+[tool.black.format]
+color = true
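
With `select = ["E", "F"]`, ruff enforces pycodestyle errors (E) and pyflakes (F) at the same 88-character line length black uses, so the two tools agree. For instance, pyflakes-class findings look like this (hypothetical snippet):

```python
import os  # F401: imported but unused (in the fixable list, so `ruff --fix` removes it)


def greet(name):
    greeting = "hello"  # F841: local variable assigned but never used
    return f"hi {name}"
```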

+ 5 - 0
requirements/dev.txt

@@ -0,0 +1,5 @@
+pip
+black==23.3.0
+isort==5.8.0
+ruff==0.0.277
+pytest==7.4.0

+ 2 - 2
setup.py

@@ -8,7 +8,7 @@ setuptools.setup(
     version="0.0.18",
     version="0.0.18",
     author="Taranjeet Singh",
     author="Taranjeet Singh",
     author_email="reachtotj@gmail.com",
     author_email="reachtotj@gmail.com",
-    description="embedchain is a framework to easily create LLM powered bots over any dataset",
+    description="embedchain is a framework to easily create LLM powered bots over any dataset",  # noqa:E501
     long_description=long_description,
     long_description=long_description,
     long_description_content_type="text/markdown",
     long_description_content_type="text/markdown",
     url="https://github.com/embedchain/embedchain",
     url="https://github.com/embedchain/embedchain",
@@ -18,7 +18,7 @@ setuptools.setup(
         "License :: OSI Approved :: Apache Software License",
         "License :: OSI Approved :: Apache Software License",
         "Operating System :: OS Independent",
         "Operating System :: OS Independent",
     ],
     ],
-    python_requires='>=3.8',
+    python_requires=">=3.8",
     py_modules=["embedchain"],
     py_modules=["embedchain"],
     install_requires=[
     install_requires=[
         "langchain>=0.0.205",
         "langchain>=0.0.205",

+ 2 - 2
tests/test_embedchain.py

@@ -1,7 +1,7 @@
 import os
-
 import unittest
-from unittest.mock import patch, MagicMock
+from unittest.mock import MagicMock, patch
+
 from embedchain import App
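
The moved blank line and reordered names are isort's conventions: standard-library imports are grouped first and alphabetized (`MagicMock` before `patch`), with third-party and local packages like `embedchain` in a separate group below. In general:

```python
# isort groups: standard library first, sorted...
import os
import unittest
from unittest.mock import MagicMock, patch

# ...then a blank line before third-party/local imports.
from embedchain import App
```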