فهرست منبع

Introduce chunker config in yaml config (#907)

Sidharth Mohanty 1 سال پیش
والد
کامیت
a1de238716
5فایلهای تغییر یافته به همراه42 افزوده شده و 8 حذف شده
  1. 7 2
      embedchain/apps/app.py
  2. 14 1
      embedchain/config/add_config.py
  3. 8 4
      embedchain/embedchain.py
  4. 8 1
      embedchain/pipeline.py
  5. 5 0
      embedchain/utils.py

+ 7 - 2
embedchain/apps/app.py

@@ -2,7 +2,7 @@ from typing import Optional
 
 
 import yaml
 import yaml
 
 
-from embedchain.config import AppConfig, BaseEmbedderConfig, BaseLlmConfig
+from embedchain.config import AppConfig, BaseEmbedderConfig, BaseLlmConfig, ChunkerConfig
 from embedchain.config.vectordb.base import BaseVectorDbConfig
 from embedchain.config.vectordb.base import BaseVectorDbConfig
 from embedchain.embedchain import EmbedChain
 from embedchain.embedchain import EmbedChain
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.embedder.base import BaseEmbedder
@@ -38,6 +38,7 @@ class App(EmbedChain):
         embedder: BaseEmbedder = None,
         embedder: BaseEmbedder = None,
         embedder_config: Optional[BaseEmbedderConfig] = None,
         embedder_config: Optional[BaseEmbedderConfig] = None,
         system_prompt: Optional[str] = None,
         system_prompt: Optional[str] = None,
+        chunker: Optional[ChunkerConfig] = None,
     ):
     ):
         """
         """
         Initialize a new `App` instance.
         Initialize a new `App` instance.
@@ -97,6 +98,9 @@ class App(EmbedChain):
         if embedder is None:
         if embedder is None:
             embedder = OpenAIEmbedder(config=embedder_config)
             embedder = OpenAIEmbedder(config=embedder_config)
 
 
+        self.chunker = None
+        if chunker:
+            self.chunker = ChunkerConfig(**chunker)
         # Type check assignments
         # Type check assignments
         if not isinstance(llm, BaseLlm):
         if not isinstance(llm, BaseLlm):
             raise TypeError(
             raise TypeError(
@@ -137,6 +141,7 @@ class App(EmbedChain):
         llm_config_data = config_data.get("llm", {})
         llm_config_data = config_data.get("llm", {})
         db_config_data = config_data.get("vectordb", {})
         db_config_data = config_data.get("vectordb", {})
         embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
         embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
+        chunker_config_data = config_data.get("chunker", {})
 
 
         app_config = AppConfig(**app_config_data.get("config", {}))
         app_config = AppConfig(**app_config_data.get("config", {}))
 
 
@@ -148,4 +153,4 @@ class App(EmbedChain):
 
 
         embedder_provider = embedding_model_config_data.get("provider", "openai")
         embedder_provider = embedding_model_config_data.get("provider", "openai")
         embedder = EmbedderFactory.create(embedder_provider, embedding_model_config_data.get("config", {}))
         embedder = EmbedderFactory.create(embedder_provider, embedding_model_config_data.get("config", {}))
-        return cls(config=app_config, llm=llm, db=db, embedder=embedder)
+        return cls(config=app_config, llm=llm, db=db, embedder=embedder, chunker=chunker_config_data)

+ 14 - 1
embedchain/config/add_config.py

@@ -1,3 +1,5 @@
+import builtins
+from importlib import import_module
 from typing import Callable, Optional
 from typing import Callable, Optional
 
 
 from embedchain.config.base_config import BaseConfig
 from embedchain.config.base_config import BaseConfig
@@ -18,7 +20,18 @@ class ChunkerConfig(BaseConfig):
     ):
     ):
         self.chunk_size = chunk_size if chunk_size else 2000
         self.chunk_size = chunk_size if chunk_size else 2000
         self.chunk_overlap = chunk_overlap if chunk_overlap else 0
         self.chunk_overlap = chunk_overlap if chunk_overlap else 0
-        self.length_function = length_function if length_function else len
+        if isinstance(length_function, str):
+            self.length_function = self.load_func(length_function)
+        else:
+            self.length_function = length_function if length_function else len
+
+    def load_func(self, dotpath: str):
+        if "." not in dotpath:
+            return getattr(builtins, dotpath)
+        else:
+            module_, func = dotpath.rsplit(".", maxsplit=1)
+            m = import_module(module_)
+            return getattr(m, func)
 
 
 
 
 @register_deserializable
 @register_deserializable

+ 8 - 4
embedchain/embedchain.py

@@ -10,15 +10,14 @@ from dotenv import load_dotenv
 from langchain.docstore.document import Document
 from langchain.docstore.document import Document
 
 
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.chunkers.base_chunker import BaseChunker
-from embedchain.config import AddConfig, BaseLlmConfig
+from embedchain.config import AddConfig, BaseLlmConfig, ChunkerConfig
 from embedchain.config.apps.base_app_config import BaseAppConfig
 from embedchain.config.apps.base_app_config import BaseAppConfig
 from embedchain.data_formatter import DataFormatter
 from embedchain.data_formatter import DataFormatter
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.helper.json_serializable import JSONSerializable
 from embedchain.helper.json_serializable import JSONSerializable
 from embedchain.llm.base import BaseLlm
 from embedchain.llm.base import BaseLlm
 from embedchain.loaders.base_loader import BaseLoader
 from embedchain.loaders.base_loader import BaseLoader
-from embedchain.models.data_type import (DataType, DirectDataType,
-                                         IndirectDataType, SpecialDataType)
+from embedchain.models.data_type import DataType, DirectDataType, IndirectDataType, SpecialDataType
 from embedchain.telemetry.posthog import AnonymousTelemetry
 from embedchain.telemetry.posthog import AnonymousTelemetry
 from embedchain.utils import detect_datatype, is_valid_json_string
 from embedchain.utils import detect_datatype, is_valid_json_string
 from embedchain.vectordb.base import BaseVectorDB
 from embedchain.vectordb.base import BaseVectorDB
@@ -84,6 +83,7 @@ class EmbedChain(JSONSerializable):
         # Attributes that aren't subclass related.
         # Attributes that aren't subclass related.
         self.user_asks = []
         self.user_asks = []
 
 
+        self.chunker: ChunkerConfig = None
         # Send anonymous telemetry
         # Send anonymous telemetry
         self._telemetry_props = {"class": self.__class__.__name__}
         self._telemetry_props = {"class": self.__class__.__name__}
         self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
         self.telemetry = AnonymousTelemetry(enabled=self.config.collect_metrics)
@@ -157,7 +157,11 @@ class EmbedChain(JSONSerializable):
         :return: source_hash, a md5-hash of the source, in hexadecimal representation.
         :return: source_hash, a md5-hash of the source, in hexadecimal representation.
         :rtype: str
         :rtype: str
         """
         """
-        if config is None:
+        if config is not None:
+            pass
+        elif self.chunker is not None:
+            config = AddConfig(chunker=self.chunker)
+        else:
             config = AddConfig()
             config = AddConfig()
 
 
         try:
         try:

+ 8 - 1
embedchain/pipeline.py

@@ -9,7 +9,7 @@ import requests
 import yaml
 import yaml
 
 
 from embedchain import Client
 from embedchain import Client
-from embedchain.config import PipelineConfig
+from embedchain.config import PipelineConfig, ChunkerConfig
 from embedchain.embedchain import CONFIG_DIR, EmbedChain
 from embedchain.embedchain import CONFIG_DIR, EmbedChain
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.embedder.base import BaseEmbedder
 from embedchain.embedder.openai import OpenAIEmbedder
 from embedchain.embedder.openai import OpenAIEmbedder
@@ -44,6 +44,7 @@ class Pipeline(EmbedChain):
         yaml_path: str = None,
         yaml_path: str = None,
         log_level=logging.INFO,
         log_level=logging.INFO,
         auto_deploy: bool = False,
         auto_deploy: bool = False,
+        chunker: ChunkerConfig = None,
     ):
     ):
         """
         """
         Initialize a new `App` instance.
         Initialize a new `App` instance.
@@ -84,6 +85,10 @@ class Pipeline(EmbedChain):
         # pipeline_id from the backend
         # pipeline_id from the backend
         self.id = None
         self.id = None
 
 
+        self.chunker = None
+        if chunker:
+            self.chunker = ChunkerConfig(**chunker)
+
         self.config = config or PipelineConfig()
         self.config = config or PipelineConfig()
         self.name = self.config.name
         self.name = self.config.name
 
 
@@ -366,6 +371,7 @@ class Pipeline(EmbedChain):
         db_config_data = config_data.get("vectordb", {})
         db_config_data = config_data.get("vectordb", {})
         embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
         embedding_model_config_data = config_data.get("embedding_model", config_data.get("embedder", {}))
         llm_config_data = config_data.get("llm", {})
         llm_config_data = config_data.get("llm", {})
+        chunker_config_data = config_data.get("chunker", {})
 
 
         pipeline_config = PipelineConfig(**pipeline_config_data)
         pipeline_config = PipelineConfig(**pipeline_config_data)
 
 
@@ -394,4 +400,5 @@ class Pipeline(EmbedChain):
             embedding_model=embedding_model,
             embedding_model=embedding_model,
             yaml_path=yaml_path,
             yaml_path=yaml_path,
             auto_deploy=auto_deploy,
             auto_deploy=auto_deploy,
+            chunker=chunker_config_data,
         )
         )

+ 5 - 0
embedchain/utils.py

@@ -350,6 +350,11 @@ def validate_yaml_config(config_data):
                     Optional("deployment_name"): str,
                     Optional("deployment_name"): str,
                 },
                 },
             },
             },
+            Optional("chunker"): {
+                Optional("chunk_size"): int,
+                Optional("chunk_overlap"): int,
+                Optional("length_function"): str,
+            },
         }
         }
     )
     )