Browse Source

chore: load chunker from config (#270)

cachho 2 years ago
parent
commit
9c58627372

+ 3 - 9
docs/advanced/configuration.mdx

@@ -13,7 +13,7 @@ Here's the readme example with configuration options.
 ```python
 import os
 from embedchain import App
-from embedchain.config import InitConfig, AddConfig, QueryConfig
+from embedchain.config import InitConfig, AddConfig, QueryConfig, ChunkerConfig
 from chromadb.utils import embedding_functions
 
 # Example: use your own embedding function
@@ -25,14 +25,8 @@ config = InitConfig(ef=embedding_functions.OpenAIEmbeddingFunction(
 naval_chat_bot = App(config)
 
 # Example: define your own chunker config for `youtube_video`
-youtube_add_config = {
-        "chunker": {
-                "chunk_size": 1000,
-                "chunk_overlap": 100,
-                "length_function": len,
-        }
-}
-naval_chat_bot.add("youtube_video", "https://www.youtube.com/watch?v=3qHkcs3kG44", AddConfig(**youtube_add_config))
+chunker_config = ChunkerConfig(chunk_size=1000, chunk_overlap=100, length_function=len)
+naval_chat_bot.add("youtube_video", "https://www.youtube.com/watch?v=3qHkcs3kG44", AddConfig(chunker=chunker_config))
 
 add_config = AddConfig()
 naval_chat_bot.add("pdf_file", "https://navalmanack.s3.amazonaws.com/Eric-Jorgenson_The-Almanack-of-Naval-Ravikant_Final.pdf", add_config)

+ 6 - 8
embedchain/chunkers/docx_file.py

@@ -5,18 +5,16 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config.AddConfig import ChunkerConfig
 
-TEXT_SPLITTER_CHUNK_PARAMS = {
-    "chunk_size": 1000,
-    "chunk_overlap": 0,
-    "length_function": len,
-}
-
 
 class DocxFileChunker(BaseChunker):
     """Chunker for .docx file."""
 
     def __init__(self, config: Optional[ChunkerConfig] = None):
         if config is None:
-            config = TEXT_SPLITTER_CHUNK_PARAMS
-        text_splitter = RecursiveCharacterTextSplitter(**config)
+            config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=config.chunk_size,
+            chunk_overlap=config.chunk_overlap,
+            length_function=config.length_function,
+        )
         super().__init__(text_splitter)

+ 6 - 8
embedchain/chunkers/pdf_file.py

@@ -5,18 +5,16 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config.AddConfig import ChunkerConfig
 
-TEXT_SPLITTER_CHUNK_PARAMS = {
-    "chunk_size": 1000,
-    "chunk_overlap": 0,
-    "length_function": len,
-}
-
 
 class PdfFileChunker(BaseChunker):
     """Chunker for PDF file."""
 
     def __init__(self, config: Optional[ChunkerConfig] = None):
         if config is None:
-            config = TEXT_SPLITTER_CHUNK_PARAMS
-        text_splitter = RecursiveCharacterTextSplitter(**config)
+            config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=config.chunk_size,
+            chunk_overlap=config.chunk_overlap,
+            length_function=config.length_function,
+        )
         super().__init__(text_splitter)

+ 6 - 8
embedchain/chunkers/qna_pair.py

@@ -5,18 +5,16 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config.AddConfig import ChunkerConfig
 
-TEXT_SPLITTER_CHUNK_PARAMS = {
-    "chunk_size": 300,
-    "chunk_overlap": 0,
-    "length_function": len,
-}
-
 
 class QnaPairChunker(BaseChunker):
     """Chunker for QnA pair."""
 
     def __init__(self, config: Optional[ChunkerConfig] = None):
         if config is None:
-            config = TEXT_SPLITTER_CHUNK_PARAMS
-        text_splitter = RecursiveCharacterTextSplitter(**config)
+            config = ChunkerConfig(chunk_size=300, chunk_overlap=0, length_function=len)
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=config.chunk_size,
+            chunk_overlap=config.chunk_overlap,
+            length_function=config.length_function,
+        )
         super().__init__(text_splitter)

+ 6 - 8
embedchain/chunkers/text.py

@@ -5,18 +5,16 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config.AddConfig import ChunkerConfig
 
-TEXT_SPLITTER_CHUNK_PARAMS = {
-    "chunk_size": 300,
-    "chunk_overlap": 0,
-    "length_function": len,
-}
-
 
 class TextChunker(BaseChunker):
     """Chunker for text."""
 
     def __init__(self, config: Optional[ChunkerConfig] = None):
         if config is None:
-            config = TEXT_SPLITTER_CHUNK_PARAMS
-        text_splitter = RecursiveCharacterTextSplitter(**config)
+            config = ChunkerConfig(chunk_size=300, chunk_overlap=0, length_function=len)
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=config.chunk_size,
+            chunk_overlap=config.chunk_overlap,
+            length_function=config.length_function,
+        )
         super().__init__(text_splitter)

+ 6 - 8
embedchain/chunkers/web_page.py

@@ -5,18 +5,16 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config.AddConfig import ChunkerConfig
 
-TEXT_SPLITTER_CHUNK_PARAMS = {
-    "chunk_size": 500,
-    "chunk_overlap": 0,
-    "length_function": len,
-}
-
 
 class WebPageChunker(BaseChunker):
     """Chunker for web page."""
 
     def __init__(self, config: Optional[ChunkerConfig] = None):
         if config is None:
-            config = TEXT_SPLITTER_CHUNK_PARAMS
-        text_splitter = RecursiveCharacterTextSplitter(**config)
+            config = ChunkerConfig(chunk_size=500, chunk_overlap=0, length_function=len)
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=config.chunk_size,
+            chunk_overlap=config.chunk_overlap,
+            length_function=config.length_function,
+        )
         super().__init__(text_splitter)

+ 6 - 8
embedchain/chunkers/youtube_video.py

@@ -5,18 +5,16 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config.AddConfig import ChunkerConfig
 
-TEXT_SPLITTER_CHUNK_PARAMS = {
-    "chunk_size": 2000,
-    "chunk_overlap": 0,
-    "length_function": len,
-}
-
 
 class YoutubeVideoChunker(BaseChunker):
     """Chunker for Youtube video."""
 
     def __init__(self, config: Optional[ChunkerConfig] = None):
         if config is None:
-            config = TEXT_SPLITTER_CHUNK_PARAMS
-        text_splitter = RecursiveCharacterTextSplitter(**config)
+            config = ChunkerConfig(chunk_size=2000, chunk_overlap=0, length_function=len)
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=config.chunk_size,
+            chunk_overlap=config.chunk_overlap,
+            length_function=config.length_function,
+        )
         super().__init__(text_splitter)

+ 6 - 6
embedchain/config/AddConfig.py

@@ -10,13 +10,13 @@ class ChunkerConfig(BaseConfig):
 
     def __init__(
         self,
-        chunk_size: Optional[int] = 4000,
-        chunk_overlap: Optional[int] = 200,
-        length_function: Optional[Callable[[str], int]] = len,
+        chunk_size: Optional[int] = None,
+        chunk_overlap: Optional[int] = None,
+        length_function: Optional[Callable[[str], int]] = None,
     ):
-        self.chunk_size = chunk_size
-        self.chunk_overlap = chunk_overlap
-        self.length_function = length_function
+        self.chunk_size = chunk_size if chunk_size else 2000
+        self.chunk_overlap = chunk_overlap if chunk_overlap else 0
+        self.length_function = length_function if length_function else len
 
 
 class LoaderConfig(BaseConfig):

+ 1 - 1
embedchain/config/__init__.py

@@ -1,4 +1,4 @@
-from .AddConfig import AddConfig  # noqa: F401
+from .AddConfig import AddConfig, ChunkerConfig  # noqa: F401
 from .BaseConfig import BaseConfig  # noqa: F401
 from .ChatConfig import ChatConfig  # noqa: F401
 from .InitConfig import InitConfig  # noqa: F401

+ 2 - 5
tests/chunkers/test_text.py

@@ -3,6 +3,7 @@
 import unittest
 
 from embedchain.chunkers.text import TextChunker
+from embedchain.config import ChunkerConfig
 
 
 class TestTextChunker(unittest.TestCase):
@@ -11,11 +12,7 @@ class TestTextChunker(unittest.TestCase):
         Test the chunks generated by TextChunker.
         # TODO: Not a very precise test.
         """
-        chunker_config = {
-            "chunk_size": 10,
-            "chunk_overlap": 0,
-            "length_function": len,
-        }
+        chunker_config = ChunkerConfig(chunk_size=10, chunk_overlap=0, length_function=len)
         chunker = TextChunker(config=chunker_config)
         text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."