Explorar el Código

Introduce chunker config in yaml config (#907)

Sidharth Mohanty hace 1 año
padre
commit
a1de238716

+ 7 - 2
embedchain/apps/app.py

@@ -2,7 +2,7 @@ from typing import Optional
 
 import yaml
 
-from embedchain.config import AppConfig, BaseEmbedderConfig, BaseLlmConfig
+from embedchain.config import AppConfig, BaseEmbedderConfig, BaseLlmConfig, ChunkerConfig
 from embedchain.config.vectordb.base import BaseVectorDbConfig
 from embedchain.embedchain import EmbedChain
 from embedchain.embedder.base import BaseEmbedder
@@ -38,6 +38,7 @@ class App(EmbedChain):
         embedder: BaseEmbedder = None,
         embedder_config: Optional[BaseEmbedderConfig] = None,
         system_prompt: Optional[str] = None,
+        chunker: Optional[ChunkerConfig] = None,
     ):
         """
         Initialize a new `App` instance.
@@ -97,6 +98,9 @@ class App(EmbedChain):
         if embedder is None:
             embedder = OpenAIEmbedder(config=embedder_config)
 
+        self.chunker = None
+        if chunker:
+            self.chunker = ChunkerConfig(**chunker)
         # Type check assignments
         if not isinstance(llm, BaseLlm):
             raise TypeError(
@@ -137,6 +141,7 @@ class App(EmbedChain):
         llm_config_data = config_data.get("llm", {})
         db_config_data = config_data.get("vectordb", {})
         embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
+        chunker_config_data = config_data.get("chunker", {})
 
         app_config = AppConfig(**app_config_data.get("config", {}))
 
@@ -148,4 +153,4 @@ class App(EmbedChain):
 
         embedder_provider = embedding_model_config_data.get("provider", "openai")
         embedder = EmbedderFactory.create(embedder_provider, embedding_model_config_data.get("config", {}))
-        return cls(config=app_config, llm=llm, db=db, embedder=embedder)
+        return cls(config=app_config, llm=llm, db=db, embedder=embedder, chunker=chunker_config_data)

+ 14 - 1
embedchain/config/add_config.py

@@ -1,3 +1,5 @@
+import builtins
+from importlib import import_module
 from typing import Callable, Optional
 
 from embedchain.config.base_config import BaseConfig
@@ -18,7 +20,18 @@ class ChunkerConfig(BaseConfig):
     ):
         self.chunk_size = chunk_size if chunk_size else 2000
         self.chunk_overlap = chunk_overlap if chunk_overlap else 0
-        self.length_function = length_function if length_function else len
+        if isinstance(length_function, str):
+            self.length_function = self.load_func(length_function)
+        else:
+            self.length_function = length_function if length_function else len
+
+    def load_func(self, dotpath: str):
+        if "." not in dotpath:
+            return getattr(builtins, dotpath)
+        else:
+            module_, func = dotpath.rsplit(".", maxsplit=1)
+            m = import_module(module_)
+            return getattr(m, func)
 
 
 @register_deserializable

+ 8 - 4
embedchain/embedchain.py

@@ -10,15 +10,14 @@ from dotenv import load_dotenv
 from langchain.docstore.document import Document
 
 from embedchain.chunkers.base_chunker import BaseChunker
-from embedchain.config import AddConfig, BaseLlmConfig
+from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
 from embedchain.config.apps.base_app_config import BaseAppConfig
 from embedchain.data_formatter import DataFormatter
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.helper.json_serializable import JSONSerializable
 from embedchain.llm.base import BaseLlm
 from embedchain.loaders.base_loader import BaseLoader
-from embedchain.models.data_type import (DataType, DirectDataType,
-                                         IndirectDataType, SpecialDataType)
+from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
 from embedchain.telemetry.posthog import AnonymousTelemetry
 from embedchain.utils import detect_datatype, is_valid_json_string
 from embedchain.vectordb.base import BaseVectorDB
@@ -84,6 +83,7 @@ class EmbedChain(JSONSerializable):
         # Attributes that aren't subclass related.
         self.user_asks = []
 
+        self.chunker: ChunkerConfig = None
         # Send anonymous telemetry
         self._telemetry_props = {"class": self.__class__.__name__}
         self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
@@ -157,7 +157,11 @@ class EmbedChain(JSONSerializable):
         :return: source_hash, a md5-hash of the source, in hexadecimal representation.
         :rtype: str
         """
-        if config is None:
+        if config is not None:
+            pass
+        elif self.chunker is not None:
+            config = AddConfig(chunker=self.chunker)
+        else:
             config = AddConfig()
 
         try:

+ 8 - 1
embedchain/pipeline.py

@@ -9,7 +9,7 @@ import requests
 import yaml
 
 from embedchain import Client
-from embedchain.config import PipelineConfig
+from embedchain.config import PipelineConfig, ChunkerConfig
 from embedchain.embedchain import CONFIG_DIR, EmbedChain
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.embedder.openai import OpenAIEmbedder
@@ -44,6 +44,7 @@ class Pipeline(EmbedChain):
         yaml_path: str = None,
         log_level=logging.INFO,
         auto_deploy: bool = False,
+        chunker: ChunkerConfig = None,
     ):
         """
         Initialize a new `Pipeline` instance.
@@ -84,6 +85,10 @@ class Pipeline(EmbedChain):
         # pipeline_id from the backend
         self.id = None
 
+        self.chunker = None
+        if chunker:
+            self.chunker = ChunkerConfig(**chunker)
+
         self.config = config or PipelineConfig()
         self.name = self.config.name
 
@@ -366,6 +371,7 @@ class Pipeline(EmbedChain):
         db_config_data = config_data.get("vectordb", {})
         embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
         llm_config_data = config_data.get("llm", {})
+        chunker_config_data = config_data.get("chunker", {})
 
         pipeline_config = PipelineConfig(**pipeline_config_data)
 
@@ -394,4 +400,5 @@ class Pipeline(EmbedChain):
             embedding_model=embedding_model,
             yaml_path=yaml_path,
             auto_deploy=auto_deploy,
+            chunker=chunker_config_data,
         )

+ 5 - 0
embedchain/utils.py

@@ -350,6 +350,11 @@ def validate_yaml_config(config_data):
                     Optional("deployment_name"): str,
                 },
             },
+            Optional("chunker"): {
+                Optional("chunk_size"): int,
+                Optional("chunk_overlap"): int,
+                Optional("length_function"): str,
+            },
         }
     )