Selaa lähdekoodia

[Feature] RSS Feed loader (#942)

Sidharth Mohanty 1 vuosi sitten
vanhempi
commit
d8897ce356

+ 22 - 0
embedchain/chunkers/rss_feed.py

@@ -0,0 +1,22 @@
+from typing import Optional
+
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+from embedchain.chunkers.base_chunker import BaseChunker
+from embedchain.config.add_config import ChunkerConfig
+from embedchain.helper.json_serializable import register_deserializable
+
+
+@register_deserializable
+class RSSFeedChunker(BaseChunker):
+    """Chunker for RSS Feed."""
+
+    def __init__(self, config: Optional[ChunkerConfig] = None):
+        if config is None:
+            config = ChunkerConfig(chunk_size=2000, chunk_overlap=0, length_function=len)
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=config.chunk_size,
+            chunk_overlap=config.chunk_overlap,
+            length_function=config.length_function,
+        )
+        super().__init__(text_splitter)

+ 2 - 0
embedchain/data_formatter/data_formatter.py

@@ -72,6 +72,7 @@ class DataFormatter(JSONSerializable):
             DataType.SUBSTACK: "embedchain.loaders.substack.SubstackLoader",
             DataType.YOUTUBE_CHANNEL: "embedchain.loaders.youtube_channel.YoutubeChannelLoader",
             DataType.DISCORD: "embedchain.loaders.discord.DiscordLoader",
+            DataType.RSSFEED: "embedchain.loaders.rss_feed.RSSFeedLoader",
             DataType.BEEHIIV: "embedchain.loaders.beehiiv.BeehiivLoader",
         }
 
@@ -113,6 +114,7 @@ class DataFormatter(JSONSerializable):
             DataType.YOUTUBE_CHANNEL: "embedchain.chunkers.common_chunker.CommonChunker",
             DataType.DISCORD: "embedchain.chunkers.common_chunker.CommonChunker",
             DataType.CUSTOM: "embedchain.chunkers.common_chunker.CommonChunker",
+            DataType.RSSFEED: "embedchain.chunkers.rss_feed.RSSFeedChunker",
             DataType.BEEHIIV: "embedchain.chunkers.beehiiv.BeehiivChunker",
         }
 

+ 52 - 0
embedchain/loaders/rss_feed.py

@@ -0,0 +1,52 @@
+import hashlib
+
+from embedchain.helper.json_serializable import register_deserializable
+from embedchain.loaders.base_loader import BaseLoader
+
+
+@register_deserializable
+class RSSFeedLoader(BaseLoader):
+    """Loader for RSS Feed."""
+
+    def load_data(self, url):
+        """Load data from a rss feed."""
+        output = self.get_rss_content(url)
+        doc_id = hashlib.sha256((str(output) + url).encode()).hexdigest()
+        return {
+            "doc_id": doc_id,
+            "data": output,
+        }
+
+    @staticmethod
+    def serialize_metadata(metadata):
+        for key, value in metadata.items():
+            if not isinstance(value, (str, int, float, bool)):
+                metadata[key] = str(value)
+
+        return metadata
+
+    @staticmethod
+    def get_rss_content(url: str):
+        try:
+            from langchain.document_loaders import RSSFeedLoader as LangchainRSSFeedLoader
+        except ImportError:
+            raise ImportError(
+                """RSSFeedLoader file requires extra dependencies.
+                Install with `pip install --upgrade "embedchain[rss_feed]"`"""
+            ) from None
+
+        output = []
+        loader = LangchainRSSFeedLoader(urls=[url])
+        data = loader.load()
+
+        for entry in data:
+            meta_data = RSSFeedLoader.serialize_metadata(entry.metadata)
+            meta_data.update({"url": url})
+            output.append(
+                {
+                    "content": entry.page_content,
+                    "meta_data": meta_data,
+                }
+            )
+
+        return output

+ 2 - 0
embedchain/models/data_type.py

@@ -33,6 +33,7 @@ class IndirectDataType(Enum):
     YOUTUBE_CHANNEL = "youtube_channel"
     DISCORD = "discord"
     CUSTOM = "custom"
+    RSSFEED = "rss_feed"
     BEEHIIV = "beehiiv"
 
 
@@ -66,4 +67,5 @@ class DataType(Enum):
     YOUTUBE_CHANNEL = IndirectDataType.YOUTUBE_CHANNEL.value
     DISCORD = IndirectDataType.DISCORD.value
     CUSTOM = IndirectDataType.CUSTOM.value
+    RSSFEED = IndirectDataType.RSSFEED.value
     BEEHIIV = IndirectDataType.BEEHIIV.value

+ 1 - 2
embedchain/utils.py

@@ -196,8 +196,7 @@ def detect_datatype(source: Any) -> DataType:
     formatted_source = format_source(str(source), 30)
 
     if url:
-        from langchain.document_loaders.youtube import \
-            ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
+        from langchain.document_loaders.youtube import ALLOWED_NETLOCK as YOUTUBE_ALLOWED_NETLOCS
 
         if url.netloc in YOUTUBE_ALLOWED_NETLOCS:
             logging.debug(f"Source of `{formatted_source}` detected as `youtube_video`.")

+ 158 - 13
poetry.lock

@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
 
 [[package]]
 name = "aiofiles"
@@ -1083,6 +1083,17 @@ ssh = ["bcrypt (>=3.1.5)"]
 test = ["pretend", "pytest (>=6.2.0)", "pytest-benchmark", "pytest-cov", "pytest-xdist"]
 test-randomorder = ["pytest-randomly"]
 
+[[package]]
+name = "cssselect"
+version = "1.2.0"
+description = "cssselect parses CSS3 Selectors and translates them to XPath 1.0"
+optional = true
+python-versions = ">=3.7"
+files = [
+    {file = "cssselect-1.2.0-py2.py3-none-any.whl", hash = "sha256:da1885f0c10b60c03ed5eccbb6b68d6eff248d91976fcde348f395d54c9fd35e"},
+    {file = "cssselect-1.2.0.tar.gz", hash = "sha256:666b19839cfaddb9ce9d36bfe4c969132c647b92fc9088c4e23f786b30f1b3dc"},
+]
+
 [[package]]
 name = "cycler"
 version = "0.12.1"
@@ -1446,6 +1457,35 @@ lz4 = ["lz4"]
 snappy = ["python-snappy"]
 zstandard = ["zstandard"]
 
+[[package]]
+name = "feedfinder2"
+version = "0.0.4"
+description = "Find the feed URLs for a website."
+optional = true
+python-versions = "*"
+files = [
+    {file = "feedfinder2-0.0.4.tar.gz", hash = "sha256:3701ee01a6c85f8b865a049c30ba0b4608858c803fe8e30d1d289fdbe89d0efe"},
+]
+
+[package.dependencies]
+beautifulsoup4 = "*"
+requests = "*"
+six = "*"
+
+[[package]]
+name = "feedparser"
+version = "6.0.10"
+description = "Universal feed parser, handles RSS 0.9x, RSS 1.0, RSS 2.0, CDF, Atom 0.3, and Atom 1.0 feeds"
+optional = true
+python-versions = ">=3.6"
+files = [
+    {file = "feedparser-6.0.10-py3-none-any.whl", hash = "sha256:79c257d526d13b944e965f6095700587f27388e50ea16fd245babe4dfae7024f"},
+    {file = "feedparser-6.0.10.tar.gz", hash = "sha256:27da485f4637ce7163cdeab13a80312b93b7d0c1b775bef4a47629a3110bca51"},
+]
+
+[package.dependencies]
+sgmllib3k = "*"
+
 [[package]]
 name = "filelock"
 version = "3.12.4"
@@ -1737,12 +1777,12 @@ files = [
 google-auth = ">=2.14.1,<3.0.dev0"
 googleapis-common-protos = ">=1.56.2,<2.0.dev0"
 grpcio = [
-    {version = ">=1.33.2,<2.0dev", optional = true, markers = "extra == \"grpc\""},
     {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
+    {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
 ]
 grpcio-status = [
-    {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "extra == \"grpc\""},
     {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
+    {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
 ]
 protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0"
 requests = ">=2.18.0,<3.0.0.dev0"
@@ -1830,8 +1870,8 @@ google-api-core = {version = ">=1.31.5,<2.0.dev0 || >2.3.0,<3.0.0dev", extras =
 google-cloud-core = ">=1.6.0,<3.0.0dev"
 google-resumable-media = ">=0.6.0,<3.0dev"
 grpcio = [
-    {version = ">=1.47.0,<2.0dev", markers = "python_version < \"3.11\""},
     {version = ">=1.49.1,<2.0dev", markers = "python_version >= \"3.11\""},
+    {version = ">=1.47.0,<2.0dev", markers = "python_version < \"3.11\""},
 ]
 packaging = ">=20.0.0"
 proto-plus = ">=1.15.0,<2.0.0dev"
@@ -1882,8 +1922,8 @@ files = [
 google-api-core = {version = ">=1.34.0,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]}
 grpc-google-iam-v1 = ">=0.12.4,<1.0.0dev"
 proto-plus = [
-    {version = ">=1.22.0,<2.0.0dev", markers = "python_version < \"3.11\""},
     {version = ">=1.22.2,<2.0.0dev", markers = "python_version >= \"3.11\""},
+    {version = ">=1.22.0,<2.0.0dev", markers = "python_version < \"3.11\""},
 ]
 protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev"
 
@@ -2598,6 +2638,16 @@ files = [
     {file = "itsdangerous-2.1.2.tar.gz", hash = "sha256:5dbbc68b317e5e42f327f9021763545dc3fc3bfe22e6deb96aaf1fc38874156a"},
 ]
 
+[[package]]
+name = "jieba3k"
+version = "0.35.1"
+description = "Chinese Words Segementation Utilities"
+optional = true
+python-versions = "*"
+files = [
+    {file = "jieba3k-0.35.1.zip", hash = "sha256:980a4f2636b778d312518066be90c7697d410dd5a472385f5afced71a2db1c10"},
+]
+
 [[package]]
 name = "jinja2"
 version = "3.1.2"
@@ -2893,6 +2943,21 @@ ocr = ["google-cloud-vision (==1)", "pytesseract"]
 paddledetection = ["paddlepaddle (==2.1.0)"]
 tesseract = ["pytesseract"]
 
+[[package]]
+name = "listparser"
+version = "0.19"
+description = "Parse OPML subscription lists"
+optional = true
+python-versions = ">=3.7,<4.0"
+files = [
+    {file = "listparser-0.19-py3-none-any.whl", hash = "sha256:c3857a9e5e5342207a556ba72e5c030782971fbe587e7afc2a75b2d7c0fa5a5c"},
+    {file = "listparser-0.19.tar.gz", hash = "sha256:5aa23ae017a22e36c50ca5259a690328dd524527977d8c094ae0857887002805"},
+]
+
+[package.extras]
+http = ["requests (>=2.25.1,<3.0.0)"]
+lxml = ["lxml (>=4.6.2,<5.0.0)"]
+
 [[package]]
 name = "lit"
 version = "17.0.2"
@@ -3505,6 +3570,32 @@ doc = ["nb2plots (>=0.6)", "numpydoc (>=1.5)", "pillow (>=9.4)", "pydata-sphinx-
 extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.10)", "sympy (>=1.10)"]
 test = ["codecov (>=2.1)", "pytest (>=7.2)", "pytest-cov (>=4.0)"]
 
+[[package]]
+name = "newspaper3k"
+version = "0.2.8"
+description = "Simplified python article discovery & extraction."
+optional = true
+python-versions = "*"
+files = [
+    {file = "newspaper3k-0.2.8-py3-none-any.whl", hash = "sha256:44a864222633d3081113d1030615991c3dbba87239f6bbf59d91240f71a22e3e"},
+    {file = "newspaper3k-0.2.8.tar.gz", hash = "sha256:9f1bd3e1fb48f400c715abf875cc7b0a67b7ddcd87f50c9aeeb8fcbbbd9004fb"},
+]
+
+[package.dependencies]
+beautifulsoup4 = ">=4.4.1"
+cssselect = ">=0.9.2"
+feedfinder2 = ">=0.0.4"
+feedparser = ">=5.2.1"
+jieba3k = ">=0.35.1"
+lxml = ">=3.6.0"
+nltk = ">=3.2.1"
+Pillow = ">=3.3.0"
+python-dateutil = ">=2.5.3"
+PyYAML = ">=3.11"
+requests = ">=2.10.0"
+tinysegmenter = "0.3"
+tldextract = ">=2.0.1"
+
 [[package]]
 name = "nltk"
 version = "3.8.1"
@@ -3913,13 +4004,11 @@ files = [
 
 [package.dependencies]
 numpy = [
-    {version = ">=1.21.0", markers = "python_version <= \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\""},
-    {version = ">=1.21.2", markers = "python_version >= \"3.10\""},
-    {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\""},
-    {version = ">=1.19.3", markers = "python_version >= \"3.6\" and platform_system == \"Linux\" and platform_machine == \"aarch64\" or python_version >= \"3.9\""},
-    {version = ">=1.17.0", markers = "python_version >= \"3.7\""},
-    {version = ">=1.17.3", markers = "python_version >= \"3.8\""},
     {version = ">=1.23.5", markers = "python_version >= \"3.11\""},
+    {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""},
+    {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""},
+    {version = ">=1.19.3", markers = "platform_system == \"Linux\" and platform_machine == \"aarch64\" and python_version >= \"3.8\" and python_version < \"3.10\" or python_version > \"3.9\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_system != \"Darwin\" and python_version < \"3.10\" or python_version >= \"3.9\" and platform_machine != \"arm64\" and python_version < \"3.10\""},
+    {version = ">=1.21.0", markers = "python_version == \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\""},
 ]
 
 [[package]]
@@ -4113,8 +4202,8 @@ files = [
 
 [package.dependencies]
 numpy = [
-    {version = ">=1.22.4", markers = "python_version < \"3.11\""},
     {version = ">=1.23.2", markers = "python_version == \"3.11\""},
+    {version = ">=1.22.4", markers = "python_version < \"3.11\""},
 ]
 python-dateutil = ">=2.8.2"
 pytz = ">=2020.1"
@@ -5617,6 +5706,21 @@ urllib3 = ">=1.21.1,<3"
 socks = ["PySocks (>=1.5.6,!=1.5.7)"]
 use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
 
+[[package]]
+name = "requests-file"
+version = "1.5.1"
+description = "File transport adapter for Requests"
+optional = true
+python-versions = "*"
+files = [
+    {file = "requests-file-1.5.1.tar.gz", hash = "sha256:07d74208d3389d01c38ab89ef403af0cfec63957d53a0081d8eca738d0247d8e"},
+    {file = "requests_file-1.5.1-py2.py3-none-any.whl", hash = "sha256:dfe5dae75c12481f68ba353183c53a65e6044c923e64c24b2209f6c7570ca953"},
+]
+
+[package.dependencies]
+requests = ">=1.0.0"
+six = "*"
+
 [[package]]
 name = "requests-oauthlib"
 version = "1.3.1"
@@ -6044,6 +6148,16 @@ docs = ["entangled-cli[rich]", "mkdocs", "mkdocs-entangled-plugin", "mkdocs-mate
 rich = ["rich"]
 test = ["build", "pytest", "rich", "wheel"]
 
+[[package]]
+name = "sgmllib3k"
+version = "1.0.0"
+description = "Py3k port of sgmllib."
+optional = true
+python-versions = "*"
+files = [
+    {file = "sgmllib3k-1.0.0.tar.gz", hash = "sha256:7868fb1c8bfa764c1ac563d3cf369c381d1325d36124933a726f29fcdaa812e9"},
+]
+
 [[package]]
 name = "shapely"
 version = "2.0.2"
@@ -6230,7 +6344,7 @@ files = [
 ]
 
 [package.dependencies]
-greenlet = {version = "!=0.4.17", optional = true, markers = "platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\" or extra == \"asyncio\""}
+greenlet = {version = "!=0.4.17", optional = true, markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\" or extra == \"asyncio\""}
 typing-extensions = ">=4.2.0"
 
 [package.extras]
@@ -6405,6 +6519,36 @@ safetensors = "*"
 torch = ">=1.7"
 torchvision = "*"
 
+[[package]]
+name = "tinysegmenter"
+version = "0.3"
+description = "Very compact Japanese tokenizer"
+optional = true
+python-versions = "*"
+files = [
+    {file = "tinysegmenter-0.3.tar.gz", hash = "sha256:ed1f6d2e806a4758a73be589754384cbadadc7e1a414c81a166fc9adf2d40c6d"},
+]
+
+[[package]]
+name = "tldextract"
+version = "5.1.0"
+description = "Accurately separates a URL's subdomain, domain, and public suffix, using the Public Suffix List (PSL). By default, this includes the public ICANN TLDs and their exceptions. You can optionally support the Public Suffix List's private domains as well."
+optional = true
+python-versions = ">=3.8"
+files = [
+    {file = "tldextract-5.1.0-py3-none-any.whl", hash = "sha256:c8eecb15f556b43db6eebd21667640fb6fba9bc9539b48707432014913a78d13"},
+    {file = "tldextract-5.1.0.tar.gz", hash = "sha256:366acfb099c7eb5dc83545c391d73da6e3afe4eaec652417c3cf13b002a160e1"},
+]
+
+[package.dependencies]
+filelock = ">=3.0.8"
+idna = "*"
+requests = ">=2.1.0"
+requests-file = ">=1.4"
+
+[package.extras]
+testing = ["black", "mypy", "pytest", "pytest-gitignore", "pytest-mock", "responses", "ruff", "tox", "types-filelock", "types-requests"]
+
 [[package]]
 name = "tokenizers"
 version = "0.14.1"
@@ -7688,6 +7832,7 @@ pinecone = ["pinecone-client"]
 poe = ["fastapi-poe"]
 postgres = ["psycopg", "psycopg-binary", "psycopg-pool"]
 qdrant = ["qdrant-client"]
+rss-feed = ["feedparser", "listparser", "newspaper3k"]
 slack = ["flask", "slack-sdk"]
 streamlit = []
 vertexai = ["google-cloud-aiplatform"]

+ 4 - 0
pyproject.toml

@@ -137,6 +137,9 @@ mysql-connector-python = { version = "^8.1.0", optional = true }
 gitpython = { version = "^3.1.38", optional = true }
 yt_dlp = { version = "^2023.11.14", optional = true }
 PyGithub = { version = "^1.59.1", optional = true }
+feedparser = { version = "^6.0.10", optional = true }
+newspaper3k = { version = "^0.2.8", optional = true }
+listparser = { version = "^0.19", optional = true }
 
 [tool.poetry.group.dev.dependencies]
 black = "^23.3.0"
@@ -198,6 +201,7 @@ youtube = [
     "yt_dlp",
     "youtube-transcript-api",
 ]
+rss_feed = ["feedparser", "listparser", "newspaper3k"]
 
 [tool.poetry.group.docs.dependencies]