Browse Source

refactor: loader chunker typing (#324)

cachho 2 years ago
parent
commit
55bfd7cafe

+ 3 - 1
embedchain/embedchain.py

@@ -6,10 +6,12 @@ from dotenv import load_dotenv
 from langchain.docstore.document import Document
 from langchain.docstore.document import Document
 from langchain.memory import ConversationBufferMemory
 from langchain.memory import ConversationBufferMemory
 
 
+from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config import AddConfig, ChatConfig, QueryConfig
 from embedchain.config import AddConfig, ChatConfig, QueryConfig
 from embedchain.config.apps.BaseAppConfig import BaseAppConfig
 from embedchain.config.apps.BaseAppConfig import BaseAppConfig
 from embedchain.config.QueryConfig import DOCS_SITE_PROMPT_TEMPLATE
 from embedchain.config.QueryConfig import DOCS_SITE_PROMPT_TEMPLATE
 from embedchain.data_formatter import DataFormatter
 from embedchain.data_formatter import DataFormatter
+from embedchain.loaders.base_loader import BaseLoader
 
 
 load_dotenv()
 load_dotenv()
 
 
@@ -80,7 +82,7 @@ class EmbedChain:
             metadata,
             metadata,
         )
         )
 
 
-    def load_and_embed(self, loader, chunker, src, metadata=None):
+    def load_and_embed(self, loader: BaseLoader, chunker: BaseChunker, src, metadata=None):
         """
         """
         Loads the data from the given URL, chunks it, and adds it to database.
         Loads the data from the given URL, chunks it, and adds it to database.
 
 

+ 9 - 0
embedchain/loaders/base_loader.py

@@ -0,0 +1,9 @@
class BaseLoader:
    """
    Common base class for all loaders.

    The class holds no state of its own. It gives every loader a shared
    type, which is what `load_and_embed` uses to annotate its `loader`
    argument. Child classes override `load_data`.
    """

    def __init__(self):
        pass

    def load_data(self, src):
        """
        Implemented by child classes

        The signature mirrors the child classes, which all accept exactly
        one positional argument describing what to load (a URL, a file
        path or raw content, depending on the loader).

        :param src: The source the child class should load.
        :raises NotImplementedError: Always, because the base class has
            nothing to load; a child class must override this method.
        """
        # Fail loudly instead of returning None, so a loader that forgot to
        # override this method is detected at the call site.
        raise NotImplementedError(f"{type(self).__name__} must implement load_data()")

+ 3 - 1
embedchain/loaders/docs_site_loader.py

@@ -4,8 +4,10 @@ from urllib.parse import urljoin, urlparse
 import requests
 import requests
 from bs4 import BeautifulSoup
 from bs4 import BeautifulSoup
 
 
+from embedchain.loaders.base_loader import BaseLoader
 
 
-class DocsSiteLoader:
+
+class DocsSiteLoader(BaseLoader):
     def __init__(self):
     def __init__(self):
         self.visited_links = set()
         self.visited_links = set()
 
 

+ 3 - 1
embedchain/loaders/docx_file.py

@@ -1,7 +1,9 @@
 from langchain.document_loaders import Docx2txtLoader
 from langchain.document_loaders import Docx2txtLoader
 
 
+from embedchain.loaders.base_loader import BaseLoader
 
 
-class DocxFileLoader:
+
+class DocxFileLoader(BaseLoader):
     def load_data(self, url):
     def load_data(self, url):
         """Load data from a .docx file."""
         """Load data from a .docx file."""
         loader = Docx2txtLoader(url)
         loader = Docx2txtLoader(url)

+ 4 - 1
embedchain/loaders/local_qna_pair.py

@@ -1,4 +1,7 @@
-class LocalQnaPairLoader:
+from embedchain.loaders.base_loader import BaseLoader
+
+
+class LocalQnaPairLoader(BaseLoader):
     def load_data(self, content):
     def load_data(self, content):
         """Load data from a local QnA pair."""
         """Load data from a local QnA pair."""
         question, answer = content
         question, answer = content

+ 4 - 1
embedchain/loaders/local_text.py

@@ -1,4 +1,7 @@
-class LocalTextLoader:
+from embedchain.loaders.base_loader import BaseLoader
+
+
+class LocalTextLoader(BaseLoader):
     def load_data(self, content):
     def load_data(self, content):
         """Load data from a local text file."""
         """Load data from a local text file."""
         meta_data = {
         meta_data = {

+ 2 - 1
embedchain/loaders/pdf_file.py

@@ -1,9 +1,10 @@
 from langchain.document_loaders import PyPDFLoader
 from langchain.document_loaders import PyPDFLoader
 
 
+from embedchain.loaders.base_loader import BaseLoader
 from embedchain.utils import clean_string
 from embedchain.utils import clean_string
 
 
 
 
-class PdfFileLoader:
+class PdfFileLoader(BaseLoader):
     def load_data(self, url):
     def load_data(self, url):
         """Load data from a PDF file."""
         """Load data from a PDF file."""
         loader = PyPDFLoader(url)
         loader = PyPDFLoader(url)

+ 2 - 1
embedchain/loaders/sitemap.py

@@ -4,11 +4,12 @@ import requests
 from bs4 import BeautifulSoup
 from bs4 import BeautifulSoup
 from bs4.builder import ParserRejectedMarkup
 from bs4.builder import ParserRejectedMarkup
 
 
+from embedchain.loaders.base_loader import BaseLoader
 from embedchain.loaders.web_page import WebPageLoader
 from embedchain.loaders.web_page import WebPageLoader
 from embedchain.utils import is_readable
 from embedchain.utils import is_readable
 
 
 
 
-class SitemapLoader:
+class SitemapLoader(BaseLoader):
     def load_data(self, sitemap_url):
     def load_data(self, sitemap_url):
         """
         """
         This method takes a sitemap URL as input and retrieves
         This method takes a sitemap URL as input and retrieves

+ 2 - 1
embedchain/loaders/web_page.py

@@ -3,10 +3,11 @@ import logging
 import requests
 import requests
 from bs4 import BeautifulSoup
 from bs4 import BeautifulSoup
 
 
+from embedchain.loaders.base_loader import BaseLoader
 from embedchain.utils import clean_string
 from embedchain.utils import clean_string
 
 
 
 
-class WebPageLoader:
+class WebPageLoader(BaseLoader):
     def load_data(self, url):
     def load_data(self, url):
         """Load data from a web page."""
         """Load data from a web page."""
         response = requests.get(url)
         response = requests.get(url)

+ 2 - 1
embedchain/loaders/youtube_video.py

@@ -1,9 +1,10 @@
 from langchain.document_loaders import YoutubeLoader
 from langchain.document_loaders import YoutubeLoader
 
 
+from embedchain.loaders.base_loader import BaseLoader
 from embedchain.utils import clean_string
 from embedchain.utils import clean_string
 
 
 
 
-class YoutubeVideoLoader:
+class YoutubeVideoLoader(BaseLoader):
     def load_data(self, url):
     def load_data(self, url):
         """Load data from a Youtube video."""
         """Load data from a Youtube video."""
         loader = YoutubeLoader.from_youtube_url(url, add_video_info=True)
         loader = YoutubeLoader.from_youtube_url(url, add_video_info=True)