Procházet zdrojové kódy

refactor: add data-format restructuring (#92)

aaishikdutta před 2 roky
rodič
revize
ae1e21833c

+ 1 - 0
embedchain/data_formatter/__init__.py

@@ -0,0 +1 @@
+from .data_formatter import DataFormatter

+ 66 - 0
embedchain/data_formatter/data_formatter.py

@@ -0,0 +1,66 @@
+from embedchain.loaders.youtube_video import YoutubeVideoLoader
+from embedchain.loaders.pdf_file import PdfFileLoader
+from embedchain.loaders.web_page import WebPageLoader
+from embedchain.loaders.local_qna_pair import LocalQnaPairLoader
+from embedchain.loaders.local_text import LocalTextLoader
+from embedchain.loaders.docx_file import DocxFileLoader
+from embedchain.chunkers.youtube_video import YoutubeVideoChunker
+from embedchain.chunkers.pdf_file import PdfFileChunker
+from embedchain.chunkers.web_page import WebPageChunker
+from embedchain.chunkers.qna_pair import QnaPairChunker
+from embedchain.chunkers.text import TextChunker
+from embedchain.chunkers.docx_file import DocxFileChunker
+
+
+class DataFormatter:
+    """
+    DataFormatter is an internal utility class which abstracts the mapping for
+    loaders and chunkers to the data_type entered by the user in their
+    .add or .add_local method call 
+    """
+    def __init__(self, data_type):
+        self.loader = self._get_loader(data_type)
+        self.chunker = self._get_chunker(data_type)
+        
+    def _get_loader(self, data_type):
+        """
+        Returns the appropriate data loader for the given data type.
+
+        :param data_type: The type of the data to load.
+        :return: The loader for the given data type.
+        :raises ValueError: If an unsupported data type is provided.
+        """
+        loaders = {
+            'youtube_video': YoutubeVideoLoader(),
+            'pdf_file': PdfFileLoader(),
+            'web_page': WebPageLoader(),
+            'qna_pair': LocalQnaPairLoader(),
+            'text': LocalTextLoader(),
+            'docx': DocxFileLoader(),
+        }
+        if data_type in loaders:
+            return loaders[data_type]
+        else:
+            raise ValueError(f"Unsupported data type: {data_type}")
+
+    def _get_chunker(self, data_type):
+        """
+        Returns the appropriate chunker for the given data type.
+
+        :param data_type: The type of the data to chunk.
+        :return: The chunker for the given data type.
+        :raises ValueError: If an unsupported data type is provided.
+        """
+        chunkers = {
+            'youtube_video': YoutubeVideoChunker(),
+            'pdf_file': PdfFileChunker(),
+            'web_page': WebPageChunker(),
+            'qna_pair': QnaPairChunker(),
+            'text': TextChunker(),
+            'docx': DocxFileChunker(),
+        }
+        if data_type in chunkers:
+            return chunkers[data_type]
+        else:
+            raise ValueError(f"Unsupported data type: {data_type}")
+

+ 9 - 62
embedchain/embedchain.py

@@ -9,21 +9,7 @@ from langchain.embeddings.openai import OpenAIEmbeddings
 from langchain.memory import ConversationBufferMemory
 from embedchain.config import InitConfig, AddConfig, QueryConfig, ChatConfig
 from embedchain.config.QueryConfig import DEFAULT_PROMPT
-
-from embedchain.loaders.youtube_video import YoutubeVideoLoader
-from embedchain.loaders.pdf_file import PdfFileLoader
-from embedchain.loaders.web_page import WebPageLoader
-from embedchain.loaders.local_qna_pair import LocalQnaPairLoader
-from embedchain.loaders.local_text import LocalTextLoader
-from embedchain.loaders.docx_file import DocxFileLoader
-from embedchain.chunkers.youtube_video import YoutubeVideoChunker
-from embedchain.chunkers.pdf_file import PdfFileChunker
-from embedchain.chunkers.web_page import WebPageChunker
-from embedchain.chunkers.qna_pair import QnaPairChunker
-from embedchain.chunkers.text import TextChunker
-from embedchain.chunkers.docx_file import DocxFileChunker
-from embedchain.vectordb.chroma_db import ChromaDB
-
+from embedchain.data_formatter import DataFormatter
 
 gpt4all_model = None
 
@@ -49,48 +35,9 @@ class EmbedChain:
         self.collection = self.config.db.collection
         self.user_asks = []
 
-    def _get_loader(self, data_type):
-        """
-        Returns the appropriate data loader for the given data type.
-
-        :param data_type: The type of the data to load.
-        :return: The loader for the given data type.
-        :raises ValueError: If an unsupported data type is provided.
-        """
-        loaders = {
-            'youtube_video': YoutubeVideoLoader(),
-            'pdf_file': PdfFileLoader(),
-            'web_page': WebPageLoader(),
-            'qna_pair': LocalQnaPairLoader(),
-            'text': LocalTextLoader(),
-            'docx': DocxFileLoader(),
-        }
-        if data_type in loaders:
-            return loaders[data_type]
-        else:
-            raise ValueError(f"Unsupported data type: {data_type}")
-
-    def _get_chunker(self, data_type):
-        """
-        Returns the appropriate chunker for the given data type.
-
-        :param data_type: The type of the data to chunk.
-        :return: The chunker for the given data type.
-        :raises ValueError: If an unsupported data type is provided.
-        """
-        chunkers = {
-            'youtube_video': YoutubeVideoChunker(),
-            'pdf_file': PdfFileChunker(),
-            'web_page': WebPageChunker(),
-            'qna_pair': QnaPairChunker(),
-            'text': TextChunker(),
-            'docx': DocxFileChunker(),
-        }
-        if data_type in chunkers:
-            return chunkers[data_type]
-        else:
-            raise ValueError(f"Unsupported data type: {data_type}")
+    
 
+    
     def add(self, data_type, url, config: AddConfig = None):
         """
         Adds the data from the given URL to the vector db.
@@ -103,10 +50,10 @@ class EmbedChain:
         """
         if config is None:
             config = AddConfig()
-        loader = self._get_loader(data_type)
-        chunker = self._get_chunker(data_type)
+        
+        data_formatter = DataFormatter(data_type)
         self.user_asks.append([data_type, url])
-        self.load_and_embed(loader, chunker, url)
+        self.load_and_embed(data_formatter.loader, data_formatter.chunker, url)
 
     def add_local(self, data_type, content, config: AddConfig = None):
         """
@@ -120,10 +67,10 @@ class EmbedChain:
         """
         if config is None:
             config = AddConfig()
-        loader = self._get_loader(data_type)
-        chunker = self._get_chunker(data_type)
+        
+        data_formatter = DataFormatter(data_type)
         self.user_asks.append([data_type, content])
-        self.load_and_embed(loader, chunker, content)
+        self.load_and_embed(data_formatter.loader, data_formatter.chunker, content)
 
     def load_and_embed(self, loader, chunker, src):
         """