Переглянути джерело

[Improvement] Add support for min chunk size (#1007)

Deven Patel 1 рік тому
батько
коміт
c0ee680546

+ 0 - 1
README.md

@@ -71,7 +71,6 @@ elon_bot = App()
 # Embed online resources
 # Embed online resources
 elon_bot.add("https://en.wikipedia.org/wiki/Elon_Musk")
 elon_bot.add("https://en.wikipedia.org/wiki/Elon_Musk")
 elon_bot.add("https://www.forbes.com/profile/elon-musk")
 elon_bot.add("https://www.forbes.com/profile/elon-musk")
-elon_bot.add("https://www.youtube.com/watch?v=RcYjXbSJBN8")
 
 
 # Query the bot
 # Query the bot
 elon_bot.query("How many companies does Elon Musk run and name those?")
 elon_bot.query("How many companies does Elon Musk run and name those?")

+ 2 - 0
docs/api-reference/advanced/configuration.mdx

@@ -180,6 +180,8 @@ Alright, let's dive into what each key means in the yaml config above:
     - `chunk_size` (Integer): The size of each chunk of text that is sent to the language model.
     - `chunk_size` (Integer): The size of each chunk of text that is sent to the language model.
     - `chunk_overlap` (Integer): The amount of overlap between each chunk of text.
     - `chunk_overlap` (Integer): The amount of overlap between each chunk of text.
     - `length_function` (String): The function used to calculate the length of each chunk of text. In this case, it's set to 'len'. You can also use any function import directly as a string here.
     - `length_function` (String): The function used to calculate the length of each chunk of text. In this case, it's set to 'len'. You can also use any function import directly as a string here.
+    - `min_chunk_size` (Integer): The minimum size of each chunk of text that is sent to the language model. Must be less than `chunk_size`; it should also be greater than `chunk_overlap`, otherwise the setting is redundant.
+
 If you have questions about the configuration above, please feel free to reach out to us using one of the following methods:
 If you have questions about the configuration above, please feel free to reach out to us using one of the following methods:
 
 
 <Snippet file="get-help.mdx" />
 <Snippet file="get-help.mdx" />

+ 7 - 2
embedchain/chunkers/base_chunker.py

@@ -1,5 +1,8 @@
 import hashlib
 import hashlib
+import logging
+from typing import Optional
 
 
+from embedchain.config.add_config import ChunkerConfig
 from embedchain.helpers.json_serializable import JSONSerializable
 from embedchain.helpers.json_serializable import JSONSerializable
 from embedchain.models.data_type import DataType
 from embedchain.models.data_type import DataType
 
 
@@ -10,7 +13,7 @@ class BaseChunker(JSONSerializable):
         self.text_splitter = text_splitter
         self.text_splitter = text_splitter
         self.data_type = None
         self.data_type = None
 
 
-    def create_chunks(self, loader, src, app_id=None):
+    def create_chunks(self, loader, src, app_id=None, config: Optional[ChunkerConfig] = None):
         """
         """
         Loads data and chunks it.
         Loads data and chunks it.
 
 
@@ -23,6 +26,8 @@ class BaseChunker(JSONSerializable):
         documents = []
         documents = []
         chunk_ids = []
         chunk_ids = []
         idMap = {}
         idMap = {}
+        min_chunk_size = config.min_chunk_size if config is not None else 1
+        logging.info(f"[INFO] Skipping chunks smaller than {min_chunk_size} characters")
         data_result = loader.load_data(src)
         data_result = loader.load_data(src)
         data_records = data_result["data"]
         data_records = data_result["data"]
         doc_id = data_result["doc_id"]
         doc_id = data_result["doc_id"]
@@ -44,7 +49,7 @@ class BaseChunker(JSONSerializable):
             for chunk in chunks:
             for chunk in chunks:
                 chunk_id = hashlib.sha256((chunk + url).encode()).hexdigest()
                 chunk_id = hashlib.sha256((chunk + url).encode()).hexdigest()
                 chunk_id = f"{app_id}--{chunk_id}" if app_id is not None else chunk_id
                 chunk_id = f"{app_id}--{chunk_id}" if app_id is not None else chunk_id
-                if idMap.get(chunk_id) is None:
+                if idMap.get(chunk_id) is None and len(chunk) >= min_chunk_size:
                     idMap[chunk_id] = True
                     idMap[chunk_id] = True
                     chunk_ids.append(chunk_id)
                     chunk_ids.append(chunk_id)
                     documents.append(chunk)
                     documents.append(chunk)

+ 4 - 1
embedchain/chunkers/images.py

@@ -1,4 +1,5 @@
 import hashlib
 import hashlib
+import logging
 from typing import Optional
 from typing import Optional
 
 
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -20,7 +21,7 @@ class ImagesChunker(BaseChunker):
         )
         )
         super().__init__(image_splitter)
         super().__init__(image_splitter)
 
 
-    def create_chunks(self, loader, src, app_id=None):
+    def create_chunks(self, loader, src, app_id=None, config: Optional[ChunkerConfig] = None):
         """
         """
         Loads the image(s), and creates their corresponding embedding. This creates one chunk for each image
         Loads the image(s), and creates their corresponding embedding. This creates one chunk for each image
 
 
@@ -32,6 +33,8 @@ class ImagesChunker(BaseChunker):
         documents = []
         documents = []
         embeddings = []
         embeddings = []
         ids = []
         ids = []
+        min_chunk_size = config.min_chunk_size if config is not None else 0
+        logging.info(f"[INFO] Skipping chunks smaller than {min_chunk_size} characters")
         data_result = loader.load_data(src)
         data_result = loader.load_data(src)
         data_records = data_result["data"]
         data_records = data_result["data"]
         doc_id = data_result["doc_id"]
         doc_id = data_result["doc_id"]

+ 15 - 5
embedchain/config/add_config.py

@@ -1,4 +1,5 @@
 import builtins
 import builtins
+import logging
 from importlib import import_module
 from importlib import import_module
 from typing import Callable, Optional
 from typing import Callable, Optional
 
 
@@ -14,12 +15,21 @@ class ChunkerConfig(BaseConfig):
 
 
     def __init__(
     def __init__(
         self,
         self,
-        chunk_size: Optional[int] = None,
-        chunk_overlap: Optional[int] = None,
+        chunk_size: Optional[int] = 2000,
+        chunk_overlap: Optional[int] = 0,
         length_function: Optional[Callable[[str], int]] = None,
         length_function: Optional[Callable[[str], int]] = None,
+        min_chunk_size: Optional[int] = 0,
     ):
     ):
-        self.chunk_size = chunk_size if chunk_size else 2000
-        self.chunk_overlap = chunk_overlap if chunk_overlap else 0
+        self.chunk_size = chunk_size
+        self.chunk_overlap = chunk_overlap
+        self.min_chunk_size = min_chunk_size
+        if self.min_chunk_size >= self.chunk_size:
+            raise ValueError(f"min_chunk_size {min_chunk_size} should be less than chunk_size {chunk_size}")
+        if self.min_chunk_size <= self.chunk_overlap:
+            logging.warn(
+                f"min_chunk_size {min_chunk_size} should be greater than chunk_overlap {chunk_overlap}, otherwise it is redundant."  # noqa:E501
+            )
+
         if isinstance(length_function, str):
         if isinstance(length_function, str):
             self.length_function = self.load_func(length_function)
             self.length_function = self.load_func(length_function)
         else:
         else:
@@ -37,7 +47,7 @@ class ChunkerConfig(BaseConfig):
 @register_deserializable
 @register_deserializable
 class LoaderConfig(BaseConfig):
 class LoaderConfig(BaseConfig):
     """
     """
-    Config for the chunker used in `add` method
+    Config for the loader used in `add` method
     """
     """
 
 
     def __init__(self):
     def __init__(self):

+ 5 - 3
embedchain/embedchain.py

@@ -196,7 +196,7 @@ class EmbedChain(JSONSerializable):
 
 
         data_formatter = DataFormatter(data_type, config, loader, chunker)
         data_formatter = DataFormatter(data_type, config, loader, chunker)
         documents, metadatas, _ids, new_chunks = self._load_and_embed(
         documents, metadatas, _ids, new_chunks = self._load_and_embed(
-            data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, dry_run, **kwargs
+            data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, config, dry_run, **kwargs
         )
         )
         if data_type in {DataType.DOCS_SITE}:
         if data_type in {DataType.DOCS_SITE}:
             self.is_docs_site_instance = True
             self.is_docs_site_instance = True
@@ -339,6 +339,7 @@ class EmbedChain(JSONSerializable):
         src: Any,
         src: Any,
         metadata: Optional[Dict[str, Any]] = None,
         metadata: Optional[Dict[str, Any]] = None,
         source_hash: Optional[str] = None,
         source_hash: Optional[str] = None,
+        add_config: Optional[AddConfig] = None,
         dry_run=False,
         dry_run=False,
         **kwargs: Optional[Dict[str, Any]],
         **kwargs: Optional[Dict[str, Any]],
     ):
     ):
@@ -359,12 +360,13 @@ class EmbedChain(JSONSerializable):
         app_id = self.config.id if self.config is not None else None
         app_id = self.config.id if self.config is not None else None
 
 
         # Create chunks
         # Create chunks
-        embeddings_data = chunker.create_chunks(loader, src, app_id=app_id)
+        embeddings_data = chunker.create_chunks(loader, src, app_id=app_id, config=add_config.chunker)
         # spread chunking results
         # spread chunking results
         documents = embeddings_data["documents"]
         documents = embeddings_data["documents"]
         metadatas = embeddings_data["metadatas"]
         metadatas = embeddings_data["metadatas"]
         ids = embeddings_data["ids"]
         ids = embeddings_data["ids"]
         new_doc_id = embeddings_data["doc_id"]
         new_doc_id = embeddings_data["doc_id"]
+        embeddings = embeddings_data.get("embeddings")
         if existing_doc_id and existing_doc_id == new_doc_id:
         if existing_doc_id and existing_doc_id == new_doc_id:
             print("Doc content has not changed. Skipping creating chunks and embeddings")
             print("Doc content has not changed. Skipping creating chunks and embeddings")
             return [], [], [], 0
             return [], [], [], 0
@@ -429,7 +431,7 @@ class EmbedChain(JSONSerializable):
         chunks_before_addition = self.db.count()
         chunks_before_addition = self.db.count()
 
 
         self.db.add(
         self.db.add(
-            embeddings=embeddings_data.get("embeddings", None),
+            embeddings=embeddings,
             documents=documents,
             documents=documents,
             metadatas=metadatas,
             metadatas=metadatas,
             ids=ids,
             ids=ids,

+ 1 - 1
pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 [tool.poetry]
 name = "embedchain"
 name = "embedchain"
-version = "0.1.32"
+version = "0.1.33"
 description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
 description = "Data platform for LLMs - Load, index, retrieve and sync any unstructured data"
 authors = [
 authors = [
     "Taranjeet Singh <taranjeet@embedchain.ai>",
     "Taranjeet Singh <taranjeet@embedchain.ai>",

+ 13 - 0
tests/chunkers/test_base_chunker.py

@@ -4,6 +4,7 @@ from unittest.mock import MagicMock
 import pytest
 import pytest
 
 
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.chunkers.base_chunker import BaseChunker
+from embedchain.config.add_config import ChunkerConfig
 from embedchain.models.data_type import DataType
 from embedchain.models.data_type import DataType
 
 
 
 
@@ -35,6 +36,18 @@ def chunker(text_splitter_mock, data_type):
     return chunker
     return chunker
 
 
 
 
+def test_create_chunks_with_config(chunker, text_splitter_mock, loader_mock, app_id, data_type):
+    text_splitter_mock.split_text.return_value = ["Chunk 1", "long chunk"]
+    loader_mock.load_data.return_value = {
+        "data": [{"content": "Content 1", "meta_data": {"url": "URL 1"}}],
+        "doc_id": "DocID",
+    }
+    config = ChunkerConfig(chunk_size=50, chunk_overlap=0, length_function=len, min_chunk_size=10)
+    result = chunker.create_chunks(loader_mock, "test_src", app_id, config)
+
+    assert result["documents"] == ["long chunk"]
+
+
 def test_create_chunks(chunker, text_splitter_mock, loader_mock, app_id, data_type):
 def test_create_chunks(chunker, text_splitter_mock, loader_mock, app_id, data_type):
     text_splitter_mock.split_text.return_value = ["Chunk 1", "Chunk 2"]
     text_splitter_mock.split_text.return_value = ["Chunk 1", "Chunk 2"]
     loader_mock.load_data.return_value = {
     loader_mock.load_data.return_value = {

+ 2 - 2
tests/chunkers/test_image_chunker.py

@@ -11,7 +11,7 @@ class TestImageChunker(unittest.TestCase):
         Test the chunks generated by TextChunker.
         Test the chunks generated by TextChunker.
         # TODO: Not a very precise test.
         # TODO: Not a very precise test.
         """
         """
-        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len)
+        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len, min_chunk_size=0)
         chunker = ImagesChunker(config=chunker_config)
         chunker = ImagesChunker(config=chunker_config)
         # Data type must be set manually in the test
         # Data type must be set manually in the test
         chunker.set_data_type(DataType.IMAGES)
         chunker.set_data_type(DataType.IMAGES)
@@ -51,7 +51,7 @@ class TestImageChunker(unittest.TestCase):
         self.assertEqual(expected_chunks, result)
         self.assertEqual(expected_chunks, result)
 
 
     def test_word_count(self):
     def test_word_count(self):
-        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len)
+        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len, min_chunk_size=0)
         chunker = ImagesChunker(config=chunker_config)
         chunker = ImagesChunker(config=chunker_config)
         chunker.set_data_type(DataType.IMAGES)
         chunker.set_data_type(DataType.IMAGES)
 
 

+ 9 - 9
tests/chunkers/test_text.py

@@ -10,12 +10,12 @@ class TestTextChunker:
         """
         """
         Test the chunks generated by TextChunker.
         Test the chunks generated by TextChunker.
         """
         """
-        chunker_config = ChunkerConfig(chunk_size=10, chunk_overlap=0, length_function=len)
+        chunker_config = ChunkerConfig(chunk_size=10, chunk_overlap=0, length_function=len, min_chunk_size=0)
         chunker = TextChunker(config=chunker_config)
         chunker = TextChunker(config=chunker_config)
         text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
         text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
         # Data type must be set manually in the test
         # Data type must be set manually in the test
         chunker.set_data_type(DataType.TEXT)
         chunker.set_data_type(DataType.TEXT)
-        result = chunker.create_chunks(MockLoader(), text)
+        result = chunker.create_chunks(MockLoader(), text, chunker_config)
         documents = result["documents"]
         documents = result["documents"]
         assert len(documents) > 5
         assert len(documents) > 5
 
 
@@ -23,11 +23,11 @@ class TestTextChunker:
         """
         """
         Test the chunks generated by TextChunker with app_id
         Test the chunks generated by TextChunker with app_id
         """
         """
-        chunker_config = ChunkerConfig(chunk_size=10, chunk_overlap=0, length_function=len)
+        chunker_config = ChunkerConfig(chunk_size=10, chunk_overlap=0, length_function=len, min_chunk_size=0)
         chunker = TextChunker(config=chunker_config)
         chunker = TextChunker(config=chunker_config)
         text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
         text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
         chunker.set_data_type(DataType.TEXT)
         chunker.set_data_type(DataType.TEXT)
-        result = chunker.create_chunks(MockLoader(), text)
+        result = chunker.create_chunks(MockLoader(), text, chunker_config)
         documents = result["documents"]
         documents = result["documents"]
         assert len(documents) > 5
         assert len(documents) > 5
 
 
@@ -35,12 +35,12 @@ class TestTextChunker:
         """
         """
         Test that if an infinitely high chunk size is used, only one chunk is returned.
         Test that if an infinitely high chunk size is used, only one chunk is returned.
         """
         """
-        chunker_config = ChunkerConfig(chunk_size=9999999999, chunk_overlap=0, length_function=len)
+        chunker_config = ChunkerConfig(chunk_size=9999999999, chunk_overlap=0, length_function=len, min_chunk_size=0)
         chunker = TextChunker(config=chunker_config)
         chunker = TextChunker(config=chunker_config)
         text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
         text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."
         # Data type must be set manually in the test
         # Data type must be set manually in the test
         chunker.set_data_type(DataType.TEXT)
         chunker.set_data_type(DataType.TEXT)
-        result = chunker.create_chunks(MockLoader(), text)
+        result = chunker.create_chunks(MockLoader(), text, chunker_config)
         documents = result["documents"]
         documents = result["documents"]
         assert len(documents) == 1
         assert len(documents) == 1
 
 
@@ -48,18 +48,18 @@ class TestTextChunker:
         """
         """
         Test that if a chunk size of one is used, every character is a chunk.
         Test that if a chunk size of one is used, every character is a chunk.
         """
         """
-        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len)
+        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len, min_chunk_size=0)
         chunker = TextChunker(config=chunker_config)
         chunker = TextChunker(config=chunker_config)
         # We can't test with lorem ipsum because chunks are deduped, so would be recurring characters.
         # We can't test with lorem ipsum because chunks are deduped, so would be recurring characters.
         text = """0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ \t\n\r\x0b\x0c"""
         text = """0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ \t\n\r\x0b\x0c"""
         # Data type must be set manually in the test
         # Data type must be set manually in the test
         chunker.set_data_type(DataType.TEXT)
         chunker.set_data_type(DataType.TEXT)
-        result = chunker.create_chunks(MockLoader(), text)
+        result = chunker.create_chunks(MockLoader(), text, chunker_config)
         documents = result["documents"]
         documents = result["documents"]
         assert len(documents) == len(text)
         assert len(documents) == len(text)
 
 
     def test_word_count(self):
     def test_word_count(self):
-        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len)
+        chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, length_function=len, min_chunk_size=0)
         chunker = TextChunker(config=chunker_config)
         chunker = TextChunker(config=chunker_config)
         chunker.set_data_type(DataType.TEXT)
         chunker.set_data_type(DataType.TEXT)
 
 

+ 1 - 1
tests/embedchain/test_add.py

@@ -33,7 +33,7 @@ def test_add_forced_type(app):
 
 
 
 
 def test_dry_run(app):
 def test_dry_run(app):
-    chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0)
+    chunker_config = ChunkerConfig(chunk_size=1, chunk_overlap=0, min_chunk_size=0)
     text = """0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"""
     text = """0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"""
 
 
     result = app.add(source=text, config=AddConfig(chunker=chunker_config), dry_run=True)
     result = app.add(source=text, config=AddConfig(chunker=chunker_config), dry_run=True)