Browse Source

[Improvement] customize add method (#988)

Deven Patel 1 year ago
parent
commit
512cfc9466

+ 41 - 0
docs/data-sources/custom.mdx

@@ -0,0 +1,41 @@
+---
+title: '⚙️ Custom'
+---
+
+When we say "custom", we mean that you can customize the loader and chunker to your needs. This is done by passing a custom loader and chunker to the `add` method.
+
+```python
+from embedchain import Pipeline as App
+import your_loader
+import your_chunker
+
+app = App()
+loader = your_loader()
+chunker = your_chunker()
+
+app.add("source", data_type="custom", loader=loader, chunker=chunker)
+```
+
+<Note>
+    The custom loader and chunker must be a class that inherits from the [`BaseLoader`](https://github.com/embedchain/embedchain/blob/main/embedchain/loaders/base_loader.py) and [`BaseChunker`](https://github.com/embedchain/embedchain/blob/main/embedchain/chunkers/base_chunker.py) classes respectively.
+</Note>
+
+<Note>
+    If the `data_type` is not a valid data type, the `add` method will fallback to the `custom` data type and expect a custom loader and chunker to be passed by the user.
+</Note>
+
+Example:
+
+```python
+from embedchain import Pipeline as App
+from embedchain.loaders.github import GithubLoader
+
+app = App()
+
+loader = GithubLoader(config={"token": "ghp_xxx"})
+
+app.add("repo:embedchain/embedchain type:repo", data_type="github", loader=loader)
+
+app.query("What is Embedchain?")
+# Answer: Embedchain is a Data Platform for Large Language Models (LLMs). It allows users to seamlessly load, index, retrieve, and sync unstructured data in order to build dynamic, LLM-powered applications. There is also a JavaScript implementation called embedchain-js available on GitHub.
+```

+ 1 - 0
docs/data-sources/overview.mdx

@@ -26,6 +26,7 @@ Embedchain comes with built-in support for various data sources. We handle the c
   <Card title="🗨️ Discourse" href="/data-sources/discourse"></Card>
   <Card title="💬 Discord" href="/data-sources/discord"></Card>
   <Card title="📝 Github" href="/data-sources/github"></Card>
+  <Card title="⚙️ Custom" href="/data-sources/custom"></Card>
 </CardGroup>
 
 <br/ >

+ 1 - 1
embedchain/chunkers/common_chunker.py

@@ -13,7 +13,7 @@ class CommonChunker(BaseChunker):
 
     def __init__(self, config: Optional[ChunkerConfig] = None):
         if config is None:
-            config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
+            config = ChunkerConfig(chunk_size=2000, chunk_overlap=0, length_function=len)
         text_splitter = RecursiveCharacterTextSplitter(
             chunk_size=config.chunk_size,
             chunk_overlap=config.chunk_overlap,

+ 20 - 32
embedchain/data_formatter/data_formatter.py

@@ -68,23 +68,13 @@ class DataFormatter(JSONSerializable):
             DataType.DISCORD: "embedchain.loaders.discord.DiscordLoader",
         }
 
-        custom_loaders = set(
-            [
-                DataType.POSTGRES,
-                DataType.MYSQL,
-                DataType.SLACK,
-                DataType.DISCOURSE,
-                DataType.GITHUB,
-            ]
-        )
-
-        if data_type in loaders:
-            loader_class: type = self._lazy_load(loaders[data_type])
-            return loader_class()
-        elif data_type in custom_loaders:
+        if data_type == DataType.CUSTOM or ("loader" in kwargs):
             loader_class: type = kwargs.get("loader", None)
-            if loader_class is not None:
+            if loader_class:
                 return loader_class
+        elif data_type in loaders:
+            loader_class: type = self._lazy_load(loaders[data_type])
+            return loader_class()
 
         raise ValueError(
             f"Cant find the loader for {data_type}.\
@@ -112,28 +102,26 @@ class DataFormatter(JSONSerializable):
             DataType.OPENAPI: "embedchain.chunkers.openapi.OpenAPIChunker",
             DataType.GMAIL: "embedchain.chunkers.gmail.GmailChunker",
             DataType.NOTION: "embedchain.chunkers.notion.NotionChunker",
-            DataType.POSTGRES: "embedchain.chunkers.postgres.PostgresChunker",
-            DataType.MYSQL: "embedchain.chunkers.mysql.MySQLChunker",
-            DataType.SLACK: "embedchain.chunkers.slack.SlackChunker",
-            DataType.DISCOURSE: "embedchain.chunkers.discourse.DiscourseChunker",
             DataType.SUBSTACK: "embedchain.chunkers.substack.SubstackChunker",
-            DataType.GITHUB: "embedchain.chunkers.common_chunker.CommonChunker",
             DataType.YOUTUBE_CHANNEL: "embedchain.chunkers.common_chunker.CommonChunker",
             DataType.DISCORD: "embedchain.chunkers.common_chunker.CommonChunker",
+            DataType.CUSTOM: "embedchain.chunkers.common_chunker.CommonChunker",
         }
 
-        if data_type in chunker_classes:
-            if "chunker" in kwargs:
-                chunker_class = kwargs.get("chunker")
-            else:
-                chunker_class = self._lazy_load(chunker_classes[data_type])
-
+        if "chunker" in kwargs:
+            chunker_class = kwargs.get("chunker", None)
+            if chunker_class:
+                chunker = chunker_class(config)
+                chunker.set_data_type(data_type)
+                return chunker
+        elif data_type in chunker_classes:
+            chunker_class = self._lazy_load(chunker_classes[data_type])
             chunker = chunker_class(config)
             chunker.set_data_type(data_type)
             return chunker
-        else:
-            raise ValueError(
-                f"Cant find the chunker for {data_type}.\
-                    We recommend to pass the chunker to use data_type: {data_type},\
-                        check `https://docs.embedchain.ai/data-sources/overview`."
-            )
+
+        raise ValueError(
+            f"Cant find the chunker for {data_type}.\
+                We recommend to pass the chunker to use data_type: {data_type},\
+                    check `https://docs.embedchain.ai/data-sources/overview`."
+        )

+ 4 - 4
embedchain/embedchain.py

@@ -178,10 +178,10 @@ class EmbedChain(JSONSerializable):
             try:
                 data_type = DataType(data_type)
             except ValueError:
-                raise ValueError(
-                    f"Invalid data_type: '{data_type}'.",
-                    f"Please use one of the following: {[data_type.value for data_type in DataType]}",
-                ) from None
+                logging.info(
+                    f"Invalid data_type: '{data_type}', using `custom` instead.\n Check docs to pass the valid data type: `https://docs.embedchain.ai/data-sources/overview`"  # noqa: E501
+                )
+                data_type = DataType.CUSTOM
 
         if not data_type:
             data_type = detect_datatype(source)

+ 2 - 10
embedchain/models/data_type.py

@@ -29,14 +29,10 @@ class IndirectDataType(Enum):
     JSON = "json"
     OPENAPI = "openapi"
     GMAIL = "gmail"
-    POSTGRES = "postgres"
-    MYSQL = "mysql"
-    SLACK = "slack"
-    DISCOURSE = "discourse"
     SUBSTACK = "substack"
-    GITHUB = "github"
     YOUTUBE_CHANNEL = "youtube_channel"
     DISCORD = "discord"
+    CUSTOM = "custom"
 
 
 class SpecialDataType(Enum):
@@ -65,11 +61,7 @@ class DataType(Enum):
     JSON = IndirectDataType.JSON.value
     OPENAPI = IndirectDataType.OPENAPI.value
     GMAIL = IndirectDataType.GMAIL.value
-    POSTGRES = IndirectDataType.POSTGRES.value
-    MYSQL = IndirectDataType.MYSQL.value
-    SLACK = IndirectDataType.SLACK.value
-    DISCOURSE = IndirectDataType.DISCOURSE.value
     SUBSTACK = IndirectDataType.SUBSTACK.value
-    GITHUB = IndirectDataType.GITHUB.value
     YOUTUBE_CHANNEL = IndirectDataType.YOUTUBE_CHANNEL.value
     DISCORD = IndirectDataType.DISCORD.value
+    CUSTOM = IndirectDataType.CUSTOM.value

+ 1 - 1
tests/chunkers/test_chunkers.py

@@ -40,7 +40,7 @@ chunker_common_config = {
     PostgresChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
     SlackChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
     DiscourseChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
-    CommonChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
+    CommonChunker: {"chunk_size": 2000, "chunk_overlap": 0, "length_function": len},
 }