Explorar o código

feat: Update line length to 120 chars (#278)

Deshraj Yadav hai 2 anos
pai
achega
fd97fb268a

+ 1 - 1
embedchain/chunkers/base_chunker.py

@@ -46,4 +46,4 @@ class BaseChunker:
 
         Override in child class if custom logic.
         """
-        return self.text_splitter.split_text(content)
+        return self.text_splitter.split_text(content)

+ 1 - 1
embedchain/chunkers/code_docs_page.py

@@ -19,4 +19,4 @@ class CodeDocsPageChunker(BaseChunker):
         if config is None:
             config = TEXT_SPLITTER_CHUNK_PARAMS
         text_splitter = RecursiveCharacterTextSplitter(**config)
-        super().__init__(text_splitter)
+        super().__init__(text_splitter)

+ 3 - 10
embedchain/config/InitConfig.py

@@ -40,13 +40,8 @@ class InitConfig(BaseConfig):
         :raises ValueError: If the template is not valid as template should contain
         $context and $query
         """
-        if (
-            os.getenv("OPENAI_API_KEY") is None
-            and os.getenv("OPENAI_ORGANIZATION") is None
-        ):
-            raise ValueError(
-                "OPENAI_API_KEY or OPENAI_ORGANIZATION environment variables not provided"  # noqa:E501
-            )
+        if os.getenv("OPENAI_API_KEY") is None and os.getenv("OPENAI_ORGANIZATION") is None:
+            raise ValueError("OPENAI_API_KEY or OPENAI_ORGANIZATION environment variables not provided")  # noqa:E501
         self.ef = embedding_functions.OpenAIEmbeddingFunction(
             api_key=os.getenv("OPENAI_API_KEY"),
             organization_id=os.getenv("OPENAI_ORGANIZATION"),
@@ -74,8 +69,6 @@ class InitConfig(BaseConfig):
             if not isinstance(level, int):
                 raise ValueError(f"Invalid log level: {debug_level}")
 
-        logging.basicConfig(
-            format="%(asctime)s [%(name)s] [%(levelname)s] %(message)s", level=level
-        )
+        logging.basicConfig(format="%(asctime)s [%(name)s] [%(levelname)s] %(message)s", level=level)
         self.logger = logging.getLogger(__name__)
         return

+ 2 - 6
embedchain/config/QueryConfig.py

@@ -113,9 +113,7 @@ class QueryConfig(BaseConfig):
             if self.history is None:
                 raise ValueError("`template` should have `query` and `context` keys")
             else:
-                raise ValueError(
-                    "`template` should have `query`, `context` and `history` keys"
-                )
+                raise ValueError("`template` should have `query`, `context` and `history` keys")
 
         if not isinstance(stream, bool):
             raise ValueError("`stream` should be bool")
@@ -129,9 +127,7 @@ class QueryConfig(BaseConfig):
         :return: Boolean, valid (true) or invalid (false)
         """
         if self.history is None:
-            return re.search(query_re, template.template) and re.search(
-                context_re, template.template
-            )
+            return re.search(query_re, template.template) and re.search(context_re, template.template)
         else:
             return (
                 re.search(query_re, template.template)

+ 1 - 1
embedchain/data_formatter/data_formatter.py

@@ -66,7 +66,7 @@ class DataFormatter:
             "text": TextChunker(config),
             "docx": DocxFileChunker(config),
             "sitemap": WebPageChunker(config),
-            "code_docs_page": CodeDocsPageChunker(config)
+            "code_docs_page": CodeDocsPageChunker(config),
         }
         if data_type in chunkers:
             return chunkers[data_type]

+ 11 - 32
embedchain/embedchain.py

@@ -9,7 +9,7 @@ from langchain.docstore.document import Document
 from langchain.memory import ConversationBufferMemory
 
 from embedchain.config import AddConfig, ChatConfig, InitConfig, QueryConfig
-from embedchain.config.QueryConfig import DEFAULT_PROMPT, CODE_DOCS_PAGE_PROMPT_TEMPLATE
+from embedchain.config.QueryConfig import CODE_DOCS_PAGE_PROMPT_TEMPLATE, DEFAULT_PROMPT
 from embedchain.data_formatter import DataFormatter
 
 gpt4all_model = None
@@ -54,10 +54,8 @@ class EmbedChain:
 
         data_formatter = DataFormatter(data_type, config)
         self.user_asks.append([data_type, url, metadata])
-        self.load_and_embed(
-            data_formatter.loader, data_formatter.chunker, url, metadata
-        )
-        if data_type in ("code_docs_page", ):
+        self.load_and_embed(data_formatter.loader, data_formatter.chunker, url, metadata)
+        if data_type in ("code_docs_page",):
             self.is_code_docs_instance = True
 
     def add_local(self, data_type, content, metadata=None, config: AddConfig = None):
@@ -106,12 +104,8 @@ class EmbedChain:
         existing_ids = set(existing_docs["ids"])
 
         if len(existing_ids):
-            data_dict = {
-                id: (doc, meta) for id, doc, meta in zip(ids, documents, metadatas)
-            }
-            data_dict = {
-                id: value for id, value in data_dict.items() if id not in existing_ids
-            }
+            data_dict = {id: (doc, meta) for id, doc, meta in zip(ids, documents, metadatas)}
+            data_dict = {id: value for id, value in data_dict.items() if id not in existing_ids}
 
             if not data_dict:
                 print(f"All data from {src} already exists in the database.")
@@ -125,15 +119,8 @@ class EmbedChain:
         # Add metadata to each document
         metadatas_with_metadata = [meta or metadata for meta in metadatas]
 
-        self.collection.add(
-            documents=documents, metadatas=list(metadatas_with_metadata), ids=ids
-        )
-        print(
-            (
-                f"Successfully saved {src}. New chunks count: "
-                f"{self.count() - chunks_before_addition}"
-            )
-        )
+        self.collection.add(documents=documents, metadatas=list(metadatas_with_metadata), ids=ids)
+        print((f"Successfully saved {src}. New chunks count: " f"{self.count() - chunks_before_addition}"))
 
     def _format_result(self, results):
         return [
@@ -180,13 +167,9 @@ class EmbedChain:
         """
         context_string = (" | ").join(contexts)
         if not config.history:
-            prompt = config.template.substitute(
-                context=context_string, query=input_query
-            )
+            prompt = config.template.substitute(context=context_string, query=input_query)
         else:
-            prompt = config.template.substitute(
-                context=context_string, query=input_query, history=config.history
-            )
+            prompt = config.template.substitute(context=context_string, query=input_query, history=config.history)
         return prompt
 
     def get_answer_from_llm(self, prompt, config: ChatConfig):
@@ -387,17 +370,13 @@ class OpenSourceApp(EmbedChain):
         :param config: InitConfig instance to load as configuration. Optional.
         `ef` defaults to open source.
         """
-        print(
-            "Loading open source embedding model. This may take some time..."
-        )  # noqa:E501
+        print("Loading open source embedding model. This may take some time...")  # noqa:E501
         if not config:
             config = InitConfig()
 
         if not config.ef:
             config._set_embedding_function(
-                embedding_functions.SentenceTransformerEmbeddingFunction(
-                    model_name="all-MiniLM-L6-v2"
-                )
+                embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-MiniLM-L6-v2")
             )
 
         if not config.db:

+ 10 - 9
embedchain/loaders/code_docs_page.py

@@ -3,6 +3,7 @@ from bs4 import BeautifulSoup
 
 from embedchain.utils import clean_string
 
+
 class CodeDocsPageLoader:
     def load_data(self, url):
         """Load data from a web page."""
@@ -10,14 +11,14 @@ class CodeDocsPageLoader:
         data = response.content
         soup = BeautifulSoup(data, "html.parser")
         selectors = [
-            'article.bd-article',
+            "article.bd-article",
             'article[role="main"]',
-            'div.md-content',
+            "div.md-content",
             'div[role="main"]',
-            'div.container',
-            'div.section',
-            'article',
-            'main',
+            "div.container",
+            "div.section",
+            "article",
+            "main",
         ]
         content = None
         for selector in selectors:
@@ -43,11 +44,11 @@ class CodeDocsPageLoader:
             ]
         ):
             tag.string = " "
-        for div in soup.find_all("div", {'class': 'cell_output'}):
+        for div in soup.find_all("div", {"class": "cell_output"}):
             div.decompose()
-        for div in soup.find_all("div", {'class': 'output_wrapper'}):
+        for div in soup.find_all("div", {"class": "output_wrapper"}):
             div.decompose()
-        for div in soup.find_all("div", {'class': 'output'}):
+        for div in soup.find_all("div", {"class": "output"}):
             div.decompose()
         content = clean_string(soup.get_text())
         output = []

+ 2 - 2
pyproject.toml

@@ -30,7 +30,7 @@ exclude = [
     "node_modules",
     "venv",
 ]
-line-length = 88
+line-length = 120
 dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
 target-version = "py38"
 
@@ -38,7 +38,7 @@ target-version = "py38"
 max-complexity = 10
 
 [tool.black]
-line-length = 88
+line-length = 120
 target-version = ["py38", "py39", "py310", "py311"]
 include = '\.pyi?$'
 exclude = '''

+ 0 - 1
setup.py

@@ -1,6 +1,5 @@
 import setuptools
 
-
 with open("README.md", "r", encoding="utf-8") as fh:
     long_description = fh.read()