Ver Fonte

[Improvement] Add support for min chunk size (#1007)

Deven Patel há 1 ano atrás
pai
commit
c0ee680546

+ 0 - 1
README.md

@@ -71,7 +71,6 @@ elon_bot = App()
 # Embed online resources
 elon_bot.add("https://en.wikipedia.org/wiki/Elon_Musk")
 elon_bot.add("https://www.forbes.com/profile/elon-musk")
-elon_bot.add("https://www.youtube.com/watch?v=RcYjXbSJBN8")
 
 # Query the bot
 elon_bot.query("How many companies does Elon Musk run and name those?")

+ 2 - 0
docs/api-reference/advanced/configuration.mdx

@@ -180,6 +180,8 @@ Alright, let's dive into what each key means in the yaml config above:
     - `chunk_size` (Integer): The size of each chunk of text that is sent to the language model.
     - `chunk_overlap` (Integer): The amount of overlap between each chunk of text.
     - `length_function` (String): The function used to calculate the length of each chunk of text. In this case, it's set to 'len'. You can also use any function import directly as a string here.
+    - `min_chunk_size` (Integer): The minimum size of each chunk of text that is sent to the language model. Must be less than `chunk_size`, and greater than `chunk_overlap`.
+
 If you have questions about the configuration above, please feel free to reach out to us using one of the following methods:
 
 <Snippet file="get-help.mdx" />

+ 7 - 2
embedchain/chunkers/base_chunker.py

@@ -1,5 +1,8 @@
 import hashlib
+import logging
+from typing import Optional
 
+from embedchain.config.add_config import ChunkerConfig
 from embedchain.helpers.json_serializable import JSONSerializable
 from embedchain.models.data_type import DataType
 
@@ -10,7 +13,7 @@ class BaseChunker(JSONSerializable):
         self.text_splitter = text_splitter
         self.data_type = None
 
-    def create_chunks(self, loader, src, app_id=None):
+    def create_chunks(self, loader, src, app_id=None, config: Optional[ChunkerConfig] = None):
         """
         Loads data and chunks it.
 
@@ -23,6 +26,8 @@ class BaseChunker(JSONSerializable):
         documents = []
         chunk_ids = []
         idMap = {}
+        min_chunk_size = config.min_chunk_size if config is not None else 1
+        logging.info(f"[INFO] Skipping chunks smaller than {min_chunk_size} characters")
         data_result = loader.load_data(src)
         data_records = data_result["data"]
         doc_id = data_result["doc_id"]
@@ -44,7 +49,7 @@ class BaseChunker(JSONSerializable):
             for chunk in chunks:
                 chunk_id = hashlib.sha256((chunk + url).encode()).hexdigest()
                 chunk_id = f"{app_id}--{chunk_id}" if app_id is not None else chunk_id
-                if idMap.get(chunk_id) is None:
+                if idMap.get(chunk_id) is None and len(chunk) >= min_chunk_size:
                     idMap[chunk_id] = True
                     chunk_ids.append(chunk_id)
                     documents.append(chunk)

+ 4 - 1
embedchain/chunkers/images.py

@@ -1,4 +1,5 @@
 import hashlib
+import logging
 from typing import Optional
 
 from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -20,7 +21,7 @@ class ImagesChunker(BaseChunker):
         )
         super().__init__(image_splitter)
 
-    def create_chunks(self, loader, src, app_id=None):
+    def create_chunks(self, loader, src, app_id=None, config: Optional[ChunkerConfig] = None):
         """
         Loads the image(s), and creates their corresponding embedding. This creates one chunk for each image
 
@@ -32,6 +33,8 @@ class ImagesChunker(BaseChunker):
         documents = []
         embeddings = []
         ids = []
+        min_chunk_size = config.min_chunk_size if config is not None else 0
+        logging.info(f"[INFO] Skipping chunks smaller than {min_chunk_size} characters")
         data_result = loader.load_data(src)
         data_records = data_result["data"]
         doc_id = data_result["doc_id"]

+ 15 - 5
embedchain/config/add_config.py

@@ -1,4 +1,5 @@
 import builtins
+import logging
 from importlib import import_module
 from typing import Callable, Optional
 
@@ -14,12 +15,21 @@ class ChunkerConfig(BaseConfig):
 
     def __init__(
         self,
-        chunk_size: Optional[int] = None,
-        chunk_overlap: Optional[int] = None,
+        chunk_size: Optional[int] = 2000,
+        chunk_overlap: Optional[int] = 0,
         length_function: Optional[Callable[[str], int]] = None,
+        min_chunk_size: Optional[int] = 0,
     ):
-        self.chunk_size = chunk_size if chunk_size else 2000
-        self.chunk_overlap = chunk_overlap if chunk_overlap else 0
+        self.chunk_size = chunk_size
+        self.chunk_overlap = chunk_overlap
+        self.min_chunk_size = min_chunk_size
+        if self.min_chunk_size >= self.chunk_size:
+            raise ValueError(f"min_chunk_size {min_chunk_size} should be less than chunk_size {chunk_size}")
+        if self.min_chunk_size <= self.chunk_overlap:
+            logging.warn(
+                f"min_chunk_size {min_chunk_size} should be greater than chunk_overlap {chunk_overlap}, otherwise it is redundant."  # noqa:E501
+            )
+
         if isinstance(length_function, str):
             self.length_function = self.load_func(length_function)
         else:
@@ -37,7 +47,7 @@ class ChunkerConfig(BaseConfig):
 @register_deserializable
 class LoaderConfig(BaseConfig):
     """
-    Config for the chunker used in `add` method
+    Config for the loader used in `add` method
     """
 
     def __init__(self):

+ 5 - 3
embedchain/embedchain.py

@@ -196,7 +196,7 @@ class EmbedChain(JSONSerializable):
 
         data_formatter = DataFormatter(data_type, config, loader, chunker)
         documents, metadatas, _ids, new_chunks = self._load_and_embed(
-            data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, dry_run, **kwargs
+            data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, config, dry_run, **kwargs
         )
         if data_type in {DataType.DOCS_SITE}:
             self.is_docs_site_instance = True
@@ -339,6 +339,7 @@ class EmbedChain(JSONSerializable):
         src: Any,
         metadata: Optional[Dict[str, Any]] = None,
         source_hash: Optional[str] = None,
+        add_config: Optional[AddConfig] = None,
         dry_run=False,
         **kwargs: Optional[Dict[str, Any]],
     ):
@@ -359,12 +360,13 @@ class EmbedChain(JSONSerializable):
         app_id = self.config.id if self.config is not None else None
 
         # Create chunks
-        embeddings_data = chunker.create_chunks(loader, src, app_id=app_id)
+        embeddings_data = chunker.create_chunks(loader, src, app_id=app_id, config=add_config.chunker)
         # spread chunking results
         documents = embeddings_data["documents"]
         metadatas = embeddings_data["metadatas"]
         ids = embeddings_data["ids"]
         new_doc_id = embeddings_data["doc_id"]
+        embeddings = embeddings_data.get("embeddings")
         if existing_doc_id and existing_doc_id == new_doc_id:
             print("Doc content has not changed. Skipping creating chunks and embeddings")
             return [], [], [], 0
@@ -429,7 +431,7 @@ class EmbedChain(JSONSerializable):
         chunks_before_addition = self.db.count()
 
         self.db.add(
-            embeddings=embeddings_data.get("embeddings", None),
+            embeddings=embeddings,
             documents=documents,
             metadatas=metadatas,
             ids=ids,

+ 1 - 1
pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "embedchain"
-version = "0.1.32"
+version = "0.1.33"
 description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
 authors = [
     "Taranjeet Singh <taranjeet@embedchain.ai>",

+ 13 - 0
tests/chunkers/test_base_chunker.py

@@ -4,6 +4,7 @@ from unittest.mock import MagicMock
 import pytest
 
 from embedchain.chunkers.base_chunker import BaseChunker
+from embedchain.config.add_config import ChunkerConfig
 from embedchain.models.data_type import DataType
 
 
@@ -35,6 +36,18 @@ def chunker(text_splitter_mock, data_type):
     return chunker
 
 
+def test_create_chunks_with_config(chunker, text_splitter_mock, loader_mock, app_id, data_type):
+    text_splitter_mock.split_text.return_value = ["Chunk 1", "long chunk"]
+    loader_mock.load_data.return_value = {
+        "data": [{"content": "Content 1", "meta_data": {"url": "URL 1"}}],
+        "doc_id": "DocID",
+    }
+    config = ChunkerConfig(chunk_size=50, chunk_overlap=0, length_function=len, min_chunk_size=10)
+    result = chunker.create_chunks(loader_mock, "test_src", app_id, config)
+
+    assert result["documents"] == ["long chunk"]
+
+
 def test_create_chunks(chunker, text_splitter_mock, loader_mock, app_id, data_type):
     text_splitter_mock.split_text.return_value = ["Chunk 1", "Chunk 2"]
     loader_mock.load_data.return_value = {

+ 2 - 2
tests/chunkers/test_image_chunker.py

@@ -11,7 +11,7 @@ class TestImageChunker(unittest.TestCase):
         Test the chunks generated by TextChunker.
         # TODO: Not a very precise test.
         """
-        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len)
+        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len, min_chunk_size=0)
         chunker = ImagesChunker(config=chunker_config)
         # Data type must be set manually in the test
         chunker.set_data_type(DataType.IMAGES)
@@ -51,7 +51,7 @@ class TestImageChunker(unittest.TestCase):
         self.assertEqual(expected_chunks, result)
 
     def test_word_count(self):
-        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len)
+        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len, min_chunk_size=0)
         chunker = ImagesChunker(config=chunker_config)
         chunker.set_data_type(DataType.IMAGES)
 

+ 9 - 9
tests/chunkers/test_text.py

@@ -10,12 +10,12 @@ class TestTextChunker:
         """
         Test the chunks generated by TextChunker.
         """
-        chunker_config = ChunkerConfig(chunk_size=10, chunk_overlap=0, length_function=len)
+        chunker_config = ChunkerConfig(chunk_size=10, chunk_overlap=0, length_function=len, min_chunk_size=0)
         chunker = TextChunker(config=chunker_config)
         text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
         # Data type must be set manually in the test
         chunker.set_data_type(DataType.TEXT)
-        result = chunker.create_chunks(MockLoader(), text)
+        result = chunker.create_chunks(MockLoader(), text, chunker_config)
         documents = result["documents"]
         assert len(documents) > 5
 
@@ -23,11 +23,11 @@ class TestTextChunker:
         """
         Test the chunks generated by TextChunker with app_id
         """
-        chunker_config = ChunkerConfig(chunk_size=10, chunk_overlap=0, length_function=len)
+        chunker_config = ChunkerConfig(chunk_size=10, chunk_overlap=0, length_function=len, min_chunk_size=0)
         chunker = TextChunker(config=chunker_config)
         text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
         chunker.set_data_type(DataType.TEXT)
-        result = chunker.create_chunks(MockLoader(), text)
+        result = chunker.create_chunks(MockLoader(), text, chunker_config)
         documents = result["documents"]
         assert len(documents) > 5
 
@@ -35,12 +35,12 @@ class TestTextChunker:
         """
         Test that if an infinitely high chunk size is used, only one chunk is returned.
         """
-        chunker_config = ChunkerConfig(chunk_size=9999999999, chunk_overlap=0, length_function=len)
+        chunker_config = ChunkerConfig(chunk_size=9999999999, chunk_overlap=0, length_function=len, min_chunk_size=0)
         chunker = TextChunker(config=chunker_config)
         text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
         # Data type must be set manually in the test
         chunker.set_data_type(DataType.TEXT)
-        result = chunker.create_chunks(MockLoader(), text)
+        result = chunker.create_chunks(MockLoader(), text, chunker_config)
         documents = result["documents"]
         assert len(documents) == 1
 
@@ -48,18 +48,18 @@ class TestTextChunker:
         """
         Test that if a chunk size of one is used, every character is a chunk.
         """
-        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len)
+        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len, min_chunk_size=0)
         chunker = TextChunker(config=chunker_config)
         # We can't test with lorem ipsum because chunks are deduped, so would be recurring characters.
         text = """0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ \t\n\r\x0b\x0c"""
         # Data type must be set manually in the test
         chunker.set_data_type(DataType.TEXT)
-        result = chunker.create_chunks(MockLoader(), text)
+        result = chunker.create_chunks(MockLoader(), text, chunker_config)
         documents = result["documents"]
         assert len(documents) == len(text)
 
     def test_word_count(self):
-        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len)
+        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len, min_chunk_size=0)
         chunker = TextChunker(config=chunker_config)
         chunker.set_data_type(DataType.TEXT)
 

+ 1 - 1
tests/embedchain/test_add.py

@@ -33,7 +33,7 @@ def test_add_forced_type(app):
 
 
 def test_dry_run(app):
-    chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0)
+    chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, min_chunk_size=0)
     text = """0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"""
 
     result = app.add(source=text, config=AddConfig(chunker=chunker_config), dry_run=True)