Przeglądaj źródła

featL AddConfig should allow configuring Chunker (#200)

Anupam Singh 2 lat temu
rodzic
commit
eda28cc491

+ 38 - 3
README.md

@@ -377,8 +377,17 @@ config = InitConfig(ef=embedding_functions.OpenAIEmbeddingFunction(
             ))
 naval_chat_bot = App(config)
 
-add_config = AddConfig() # Currently no options
-naval_chat_bot.add("youtube_video", "https://www.youtube.com/watch?v=3qHkcs3kG44", add_config)
+# Example: define your own chunker config for `youtube_video`
+youtube_add_config = {
+        "chunker": {
+                "chunk_size": 1000,
+                "chunk_overlap": 100,
+                "length_function": len,
+        }
+}
+naval_chat_bot.add("youtube_video", "https://www.youtube.com/watch?v=3qHkcs3kG44", AddConfig(**youtube_add_config))
+
+add_config = AddConfig()
 naval_chat_bot.add("pdf_file", "https://navalmanack.s3.amazonaws.com/Eric-Jorgenson_The-Almanack-of-Naval-Ravikant_Final.pdf", add_config)
 naval_chat_bot.add("web_page", "https://nav.al/feedback", add_config)
 naval_chat_bot.add("web_page", "https://nav.al/agi", add_config)
@@ -450,13 +459,39 @@ This section describes all possible config options.
 
 #### **Add Config**
 
+|option|description|type|default|
+|---|---|---|---|
+|chunker|chunker config|ChunkerConfig|Default values for chunker depends on the `data_type`. Please refer [ChunkerConfig](#chunker-config)|
+|loader|loader config|LoaderConfig|None|
+
+##### **Chunker Config**
+
+|option|description|type|default|
+|---|---|---|---|
+|chunk_size|Maximum size of chunks to return|int|Default value for various `data_type` mentioned below|
+|chunk_overlap|Overlap in characters between chunks|int|Default value for various `data_type` mentioned below|
+|length_function|Function that measures the length of given chunks|typing.Callable|Default value for various `data_type` mentioned below|
+
+Default values of chunker config parameters for different `data_type`:
+
+|data_type|chunk_size|chunk_overlap|length_function|
+|---|---|---|---|
+|docx|1000|0|len|
+|text|300|0|len|
+|qna_pair|300|0|len|
+|web_page|500|0|len|
+|pdf_file|1000|0|len|
+|youtube_video|2000|0|len|
+
+##### **Loader Config**
+
 _coming soon_
 
 #### **Query Config**
 
 |option|description|type|default|
 |---|---|---|---|
-|template|custom template for prompt|Template|Template("Use the following pieces of context to answer the query at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. \$context Query: $query Helpful Answer:")|
+|template|custom template for prompt|Template|Template("Use the following pieces of context to answer the query at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. \$context Query: \$query Helpful Answer:")|
 |history|include conversation history from your client or database|any (recommendation: list[str])|None
 |stream|control if response is streamed back to the user|bool|False|
 

+ 7 - 2
embedchain/chunkers/docx_file.py

@@ -1,8 +1,11 @@
+from typing import Optional
 from embedchain.chunkers.base_chunker import BaseChunker
+from embedchain.config.AddConfig import ChunkerConfig
 
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 
 
+
 TEXT_SPLITTER_CHUNK_PARAMS = {
     "chunk_size": 1000,
     "chunk_overlap": 0,
@@ -11,6 +14,8 @@ TEXT_SPLITTER_CHUNK_PARAMS = {
 
 
 class DocxFileChunker(BaseChunker):
-    def __init__(self):
-        text_splitter = RecursiveCharacterTextSplitter(**TEXT_SPLITTER_CHUNK_PARAMS)
+    def __init__(self, config: Optional[ChunkerConfig] = None):
+        if config is None:
+            config = TEXT_SPLITTER_CHUNK_PARAMS
+        text_splitter = RecursiveCharacterTextSplitter(**config)
         super().__init__(text_splitter)

+ 7 - 3
embedchain/chunkers/pdf_file.py

@@ -1,4 +1,6 @@
+from typing import Optional
 from embedchain.chunkers.base_chunker import BaseChunker
+from embedchain.config.AddConfig import ChunkerConfig
 
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 
@@ -11,6 +13,8 @@ TEXT_SPLITTER_CHUNK_PARAMS = {
 
 
 class PdfFileChunker(BaseChunker):
-    def __init__(self):
-        text_splitter = RecursiveCharacterTextSplitter(**TEXT_SPLITTER_CHUNK_PARAMS)
-        super().__init__(text_splitter)
+    def __init__(self, config: Optional[ChunkerConfig] = None):
+        if config is None:
+            config = TEXT_SPLITTER_CHUNK_PARAMS
+        text_splitter = RecursiveCharacterTextSplitter(**config)
+        super().__init__(text_splitter)

+ 6 - 2
embedchain/chunkers/qna_pair.py

@@ -1,4 +1,6 @@
+from typing import Optional
 from embedchain.chunkers.base_chunker import BaseChunker
+from embedchain.config.AddConfig import ChunkerConfig
 
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 
@@ -11,6 +13,8 @@ TEXT_SPLITTER_CHUNK_PARAMS = {
 
 
 class QnaPairChunker(BaseChunker):
-    def __init__(self):
-        text_splitter = RecursiveCharacterTextSplitter(**TEXT_SPLITTER_CHUNK_PARAMS)
+    def __init__(self, config: Optional[ChunkerConfig] = None):
+        if config is None:
+            config = TEXT_SPLITTER_CHUNK_PARAMS
+        text_splitter = RecursiveCharacterTextSplitter(**config)
         super().__init__(text_splitter)

+ 6 - 2
embedchain/chunkers/text.py

@@ -1,4 +1,6 @@
+from typing import Optional
 from embedchain.chunkers.base_chunker import BaseChunker
+from embedchain.config.AddConfig import ChunkerConfig
 
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 
@@ -11,6 +13,8 @@ TEXT_SPLITTER_CHUNK_PARAMS = {
 
 
 class TextChunker(BaseChunker):
-    def __init__(self):
-        text_splitter = RecursiveCharacterTextSplitter(**TEXT_SPLITTER_CHUNK_PARAMS)
+    def __init__(self, config: Optional[ChunkerConfig] = None):
+        if config is None:
+            config = TEXT_SPLITTER_CHUNK_PARAMS
+        text_splitter = RecursiveCharacterTextSplitter(**config)
         super().__init__(text_splitter)

+ 6 - 2
embedchain/chunkers/web_page.py

@@ -1,4 +1,6 @@
+from typing import Optional
 from embedchain.chunkers.base_chunker import BaseChunker
+from embedchain.config.AddConfig import ChunkerConfig
 
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 
@@ -11,6 +13,8 @@ TEXT_SPLITTER_CHUNK_PARAMS = {
 
 
 class WebPageChunker(BaseChunker):
-    def __init__(self):
-        text_splitter = RecursiveCharacterTextSplitter(**TEXT_SPLITTER_CHUNK_PARAMS)
+    def __init__(self, config: Optional[ChunkerConfig] = None):
+        if config is None:
+            config = TEXT_SPLITTER_CHUNK_PARAMS
+        text_splitter = RecursiveCharacterTextSplitter(**config)
         super().__init__(text_splitter)

+ 7 - 3
embedchain/chunkers/youtube_video.py

@@ -1,4 +1,6 @@
+from typing import Optional
 from embedchain.chunkers.base_chunker import BaseChunker
+from embedchain.config.AddConfig import ChunkerConfig
 
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 
@@ -11,6 +13,8 @@ TEXT_SPLITTER_CHUNK_PARAMS = {
 
 
 class YoutubeVideoChunker(BaseChunker):
-    def __init__(self):
-        text_splitter = RecursiveCharacterTextSplitter(**TEXT_SPLITTER_CHUNK_PARAMS)
-        super().__init__(text_splitter)
+    def __init__(self, config: Optional[ChunkerConfig] = None):
+        if config is None:
+            config = TEXT_SPLITTER_CHUNK_PARAMS
+        text_splitter = RecursiveCharacterTextSplitter(**config)
+        super().__init__(text_splitter)

+ 26 - 2
embedchain/config/AddConfig.py

@@ -1,8 +1,32 @@
+from typing import Callable, Optional
 from embedchain.config.BaseConfig import BaseConfig
 
+
+class ChunkerConfig(BaseConfig):
+    """
+    Config for the chunker used in `add` method
+    """
+    def __init__(self,
+                 chunk_size: Optional[int] = 4000,
+                 chunk_overlap: Optional[int] = 200,
+                 length_function: Optional[Callable[[str], int]] = len):
+        self.chunk_size = chunk_size
+        self.chunk_overlap = chunk_overlap
+        self.length_function = length_function
+
+class LoaderConfig(BaseConfig):
+    """
+    Config for the chunker used in `add` method
+    """
+    def __init__(self):
+        pass
+
 class AddConfig(BaseConfig):
     """
     Config for the `add` method.
     """
-    def __init__(self):
-        pass
+    def __init__(self,
+                 chunker: Optional[ChunkerConfig] = None,
+                 loader: Optional[LoaderConfig] = None):
+        self.loader = loader
+        self.chunker = chunker

+ 13 - 13
embedchain/data_formatter/data_formatter.py

@@ -1,3 +1,4 @@
+from embedchain.config import AddConfig
 from embedchain.loaders.youtube_video import YoutubeVideoLoader
 from embedchain.loaders.pdf_file import PdfFileLoader
 from embedchain.loaders.web_page import WebPageLoader
@@ -18,11 +19,11 @@ class DataFormatter:
     loaders and chunkers to the data_type entered by the user in their
     .add or .add_local method call 
     """
-    def __init__(self, data_type):
-        self.loader = self._get_loader(data_type)
-        self.chunker = self._get_chunker(data_type)
-        
-    def _get_loader(self, data_type):
+    def __init__(self, data_type: str, config: AddConfig):
+        self.loader = self._get_loader(data_type, config.loader)
+        self.chunker = self._get_chunker(data_type, config.chunker)
+
+    def _get_loader(self, data_type, config):
         """
         Returns the appropriate data loader for the given data type.
 
@@ -43,7 +44,7 @@ class DataFormatter:
         else:
             raise ValueError(f"Unsupported data type: {data_type}")
 
-    def _get_chunker(self, data_type):
+    def _get_chunker(self, data_type, config):
         """
         Returns the appropriate chunker for the given data type.
 
@@ -52,15 +53,14 @@ class DataFormatter:
         :raises ValueError: If an unsupported data type is provided.
         """
         chunkers = {
-            'youtube_video': YoutubeVideoChunker(),
-            'pdf_file': PdfFileChunker(),
-            'web_page': WebPageChunker(),
-            'qna_pair': QnaPairChunker(),
-            'text': TextChunker(),
-            'docx': DocxFileChunker(),
+            'youtube_video': YoutubeVideoChunker(config),
+            'pdf_file': PdfFileChunker(config),
+            'web_page': WebPageChunker(config),
+            'qna_pair': QnaPairChunker(config),
+            'text': TextChunker(config),
+            'docx': DocxFileChunker(config),
         }
         if data_type in chunkers:
             return chunkers[data_type]
         else:
             raise ValueError(f"Unsupported data type: {data_type}")
-

+ 4 - 7
embedchain/embedchain.py

@@ -37,9 +37,6 @@ class EmbedChain:
         self.collection = self.config.db.collection
         self.user_asks = []
 
-    
-
-    
     def add(self, data_type, url, config: AddConfig = None):
         """
         Adds the data from the given URL to the vector db.
@@ -52,8 +49,8 @@ class EmbedChain:
         """
         if config is None:
             config = AddConfig()
-        
-        data_formatter = DataFormatter(data_type)
+
+        data_formatter = DataFormatter(data_type, config)
         self.user_asks.append([data_type, url])
         self.load_and_embed(data_formatter.loader, data_formatter.chunker, url)
 
@@ -69,8 +66,8 @@ class EmbedChain:
         """
         if config is None:
             config = AddConfig()
-        
-        data_formatter = DataFormatter(data_type)
+
+        data_formatter = DataFormatter(data_type, config)
         self.user_asks.append([data_type, content])
         self.load_and_embed(data_formatter.loader, data_formatter.chunker, content)