Explorar el Código

chore: linting (#597)

cachho hace 1 año
padre
commit
03146946fa

+ 10 - 12
embedchain/embedchain.py

@@ -242,7 +242,7 @@ class EmbedChain(JSONSerializable):
         src: Any,
         metadata: Optional[Dict[str, Any]] = None,
         source_id: Optional[str] = None,
-        dry_run = False
+        dry_run=False,
     ) -> Tuple[List[str], Dict[str, Any], List[str], int]:
         """The loader to use to load the data.
 
@@ -320,14 +320,14 @@ class EmbedChain(JSONSerializable):
         return list(documents), metadatas, ids, count_new_chunks
 
     def load_and_embed_v2(
-            self,
-            loader: BaseLoader,
-            chunker: BaseChunker,
-            src: Any,
-            metadata: Optional[Dict[str, Any]] = None,
-            source_id: Optional[str] = None,
-            dry_run = False
-        ):
+        self,
+        loader: BaseLoader,
+        chunker: BaseChunker,
+        src: Any,
+        metadata: Optional[Dict[str, Any]] = None,
+        source_id: Optional[str] = None,
+        dry_run=False,
+    ):
         """
         Loads the data from the given URL, chunks it, and adds it to database.
 
@@ -364,9 +364,7 @@ class EmbedChain(JSONSerializable):
         # this means that doc content has changed.
         if existing_doc_id and existing_doc_id != new_doc_id:
             print("Doc content has changed. Recomputing chunks and embeddings intelligently.")
-            self.db.delete({
-                "doc_id": existing_doc_id
-            })
+            self.db.delete({"doc_id": existing_doc_id})
 
         # get existing ids, and discard doc if any common id exist.
         where = {"app_id": self.config.id} if self.config.id is not None else {}

+ 1 - 4
embedchain/loaders/csv.py

@@ -46,7 +46,4 @@ class CsvLoader(BaseLoader):
                 lines.append(line)
                 result.append({"content": line, "meta_data": {"url": content, "row": i + 1}})
         doc_id = hashlib.sha256((content + " ".join(lines)).encode()).hexdigest()
-        return {
-            "doc_id": doc_id,
-            "data": result
-        }
+        return {"doc_id": doc_id, "data": result}

+ 1 - 1
embedchain/loaders/local_qna_pair.py

@@ -22,5 +22,5 @@ class LocalQnaPairLoader(BaseLoader):
                     "content": content,
                     "meta_data": meta_data,
                 }
-            ]
+            ],
         }

+ 1 - 1
embedchain/loaders/local_text.py

@@ -20,5 +20,5 @@ class LocalTextLoader(BaseLoader):
                     "content": content,
                     "meta_data": meta_data,
                 }
-            ]
+            ],
         }

+ 5 - 5
embedchain/loaders/notion.py

@@ -39,9 +39,9 @@ class NotionLoader(BaseLoader):
         return {
             "doc_id": doc_id,
             "data": [
-            {
-                "content": text,
-                "meta_data": {"url": f"notion-{formatted_id}"},
-            }
-        ],
+                {
+                    "content": text,
+                    "meta_data": {"url": f"notion-{formatted_id}"},
+                }
+            ],
         }

+ 1 - 4
embedchain/loaders/sitemap.py

@@ -43,7 +43,4 @@ class SitemapLoader(BaseLoader):
                     logging.warning(f"Page is not readable (too many invalid characters): {link}")
             except ParserRejectedMarkup as e:
                 logging.error(f"Failed to parse {link}: {e}")
-        return {
-            "doc_id": doc_id,
-            "data": [data[0] for data in output]
-        }
+        return {"doc_id": doc_id, "data": [data[0] for data in output]}

+ 1 - 1
embedchain/loaders/web_page.py

@@ -66,7 +66,7 @@ class WebPageLoader(BaseLoader):
         }
         content = content
         doc_id = hashlib.sha256((content + url).encode()).hexdigest()
-        return  {
+        return {
             "doc_id": doc_id,
             "data": [
                 {

+ 1 - 1
embedchain/vectordb/base_vector_db.py

@@ -47,4 +47,4 @@ class BaseVectorDB(JSONSerializable):
         raise NotImplementedError
 
     def set_collection_name(self, name: str):
-        raise NotImplementedError
+        raise NotImplementedError

+ 2 - 4
embedchain/vectordb/chroma.py

@@ -1,5 +1,5 @@
 import logging
-from typing import Dict, List, Optional, Any
+from typing import Any, Dict, List, Optional
 
 from chromadb import Collection, QueryResult
 from langchain.docstore.document import Document
@@ -105,9 +105,7 @@ class ChromaDB(BaseVectorDB):
             args["where"] = where
         if limit:
             args["limit"] = limit
-        return self.collection.get(
-            **args
-        )
+        return self.collection.get(**args)
 
     def get_advanced(self, where):
         return self.collection.get(where=where, limit=1)

+ 1 - 1
tests/chunkers/test_text.py

@@ -76,5 +76,5 @@ class MockLoader:
                     "content": src,
                     "meta_data": {"url": "none"},
                 }
-            ]
+            ],
         }

+ 1 - 1
tests/embedchain/test_add.py

@@ -3,7 +3,7 @@ import unittest
 from unittest.mock import MagicMock, patch
 
 from embedchain import App
-from embedchain.config import AppConfig, AddConfig, ChunkerConfig
+from embedchain.config import AddConfig, AppConfig, ChunkerConfig
 from embedchain.models.data_type import DataType