Jelajahi Sumber

[Feature] Add Postgres data loader (#918)

Co-authored-by: Deven Patel <deven298@yahoo.com>
Deven Patel 1 tahun lalu
induk
melakukan
7de8d85199

+ 1 - 0
docs/data-sources/overview.mdx

@@ -20,6 +20,7 @@ Embedchain comes with built-in support for various data sources. We handle the c
   <Card title="🙌 OpenAPI" href="/data-sources/openapi"></Card>
   <Card title="🎥📺 youtube video" href="/data-sources/youtube-video"></Card>
   <Card title="📬 Gmail" href="/data-sources/gmail"></Card>
+  <Card title="🐘 Postgres" href="/data-sources/postgres"></Card>
 </CardGroup>
 
 <br/ >

+ 64 - 0
docs/data-sources/postgres.mdx

@@ -0,0 +1,64 @@
+---
+title: '🐘 Postgres'
+---
+
+1. Setup the Postgres loader by configuring the postgres db.
+```Python
+from embedchain.loaders.postgres import PostgresLoader
+
+config = {
+    "host": "host_address",
+    "port": "port_number",
+    "dbname": "database_name",
+    "user": "username",
+    "password": "password",
+}
+
+"""
+config = {
+    "url": "your_postgres_url"
+}
+"""
+
+postgres_loader = PostgresLoader(config=config)
+
+```
+
+You can either setup the loader by passing the postgresql url or by providing the config data.
+For more details on how to setup with valid url and config, check postgres [documentation](https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING:~:text=34.1.1.%C2%A0Connection%20Strings-,%23,-Several%20libpq%20functions).
+
+NOTE: if you provide the `url` field in config, all other fields will be ignored.
+
+2. Once you setup the loader, you can create an app and load data using the above postgres loader
+```Python
+import os
+from embedchain.pipeline import Pipeline as App
+
+os.environ["OPENAI_API_KEY"] = "sk-xxx"
+
+app = App()
+
+question = "What is Elon Musk's networth?"
+response = app.query(question)
+# Answer: As of September 2021, Elon Musk's net worth is estimated to be around $250 billion, making him one of the wealthiest individuals in the world. However, please note that net worth can fluctuate over time due to various factors such as stock market changes and business ventures.
+
+app.add("SELECT * FROM table_name;", data_type='postgres', loader=postgres_loader)
+# Adds `(1, 'What is your net worth, Elon Musk?', "As of October 2023, Elon Musk's net worth is $255.2 billion.")`
+
+response = app.query(question)
+# Answer: As of October 2023, Elon Musk's net worth is $255.2 billion.
+```
+
+NOTE: The `add` function of the app will accept any executable query to load data. DO NOT pass the `CREATE`, `INSERT` queries in `add` function as they will result in not adding any data, so it is pointless.
+
+3. We automatically create a chunker to chunk your postgres data, however if you wish to provide your own chunker class. Here is how you can do that:
+```Python
+
+from embedchain.chunkers.postgres import PostgresChunker
+from embedchain.config.add_config import ChunkerConfig
+
+postgres_chunker_config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
+postgres_chunker = PostgresChunker(config=postgres_chunker_config)
+
+app.add("SELECT * FROM table_name;", data_type='postgres', loader=postgres_loader, chunker=postgres_chunker)
+```

+ 22 - 0
embedchain/chunkers/postgres.py

@@ -0,0 +1,22 @@
+from typing import Optional
+
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+from embedchain.chunkers.base_chunker import BaseChunker
+from embedchain.config.add_config import ChunkerConfig
+from embedchain.helper.json_serializable import register_deserializable
+
+
+@register_deserializable
+class PostgresChunker(BaseChunker):
+    """Chunker for postgres."""
+
+    def __init__(self, config: Optional[ChunkerConfig] = None):
+        if config is None:
+            config = ChunkerConfig(chunk_size=1000, chunk_overlap=0, length_function=len)
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=config.chunk_size,
+            chunk_overlap=config.chunk_overlap,
+            length_function=config.length_function,
+        )
+        super().__init__(text_splitter)

+ 34 - 9
embedchain/data_formatter/data_formatter.py

@@ -1,4 +1,5 @@
 from importlib import import_module
+from typing import Any, Dict
 
 from embedchain.chunkers.base_chunker import BaseChunker
 from embedchain.config import AddConfig
@@ -15,7 +16,7 @@ class DataFormatter(JSONSerializable):
     .add or .add_local method call
     """
 
-    def __init__(self, data_type: DataType, config: AddConfig):
+    def __init__(self, data_type: DataType, config: AddConfig, kwargs: Dict[str, Any]):
         """
         Initialize a dataformatter, set data type and chunker based on datatype.
 
@@ -24,15 +25,15 @@ class DataFormatter(JSONSerializable):
         :param config: AddConfig instance with nested loader and chunker config attributes.
         :type config: AddConfig
         """
-        self.loader = self._get_loader(data_type=data_type, config=config.loader)
-        self.chunker = self._get_chunker(data_type=data_type, config=config.chunker)
+        self.loader = self._get_loader(data_type=data_type, config=config.loader, kwargs=kwargs)
+        self.chunker = self._get_chunker(data_type=data_type, config=config.chunker, kwargs=kwargs)
 
     def _lazy_load(self, module_path: str):
         module_path, class_name = module_path.rsplit(".", 1)
         module = import_module(module_path)
         return getattr(module, class_name)
 
-    def _get_loader(self, data_type: DataType, config: LoaderConfig) -> BaseLoader:
+    def _get_loader(self, data_type: DataType, config: LoaderConfig, kwargs: Dict[str, Any]) -> BaseLoader:
         """
         Returns the appropriate data loader for the given data type.
 
@@ -63,13 +64,28 @@ class DataFormatter(JSONSerializable):
             DataType.GMAIL: "embedchain.loaders.gmail.GmailLoader",
             DataType.NOTION: "embedchain.loaders.notion.NotionLoader",
         }
+
+        custom_loaders = set(
+            [
+                DataType.POSTGRES,
+            ]
+        )
+
         if data_type in loaders:
             loader_class: type = self._lazy_load(loaders[data_type])
             return loader_class()
-        else:
-            raise ValueError(f"Unsupported data type: {data_type}")
+        elif data_type in custom_loaders:
+            loader_class: type = kwargs.get("loader", None)
+            if loader_class is not None:
+                return loader_class
 
-    def _get_chunker(self, data_type: DataType, config: ChunkerConfig) -> BaseChunker:
+        raise ValueError(
+            f"Cant find the loader for {data_type}.\
+                    We recommend to pass the loader to use data_type: {data_type},\
+                        check `https://docs.embedchain.ai/data-sources/overview`."
+        )
+
+    def _get_chunker(self, data_type: DataType, config: ChunkerConfig, kwargs: Dict[str, Any]) -> BaseChunker:
         """Returns the appropriate chunker for the given data type (updated for lazy loading)."""
         chunker_classes = {
             DataType.YOUTUBE_VIDEO: "embedchain.chunkers.youtube_video.YoutubeVideoChunker",
@@ -89,12 +105,21 @@ class DataFormatter(JSONSerializable):
             DataType.OPENAPI: "embedchain.chunkers.openapi.OpenAPIChunker",
             DataType.GMAIL: "embedchain.chunkers.gmail.GmailChunker",
             DataType.NOTION: "embedchain.chunkers.notion.NotionChunker",
+            DataType.POSTGRES: "embedchain.chunkers.postgres.PostgresChunker",
         }
 
         if data_type in chunker_classes:
-            chunker_class = self._lazy_load(chunker_classes[data_type])
+            if "chunker" in kwargs:
+                chunker_class = kwargs.get("chunker")
+            else:
+                chunker_class = self._lazy_load(chunker_classes[data_type])
+
             chunker = chunker_class(config)
             chunker.set_data_type(data_type)
             return chunker
         else:
-            raise ValueError(f"Unsupported data type: {data_type}")
+            raise ValueError(
+                f"Cant find the chunker for {data_type}.\
+                    We recommend to pass the chunker to use data_type: {data_type},\
+                        check `https://docs.embedchain.ai/data-sources/overview`."
+            )

+ 11 - 17
embedchain/embedchain.py

@@ -137,6 +137,7 @@ class EmbedChain(JSONSerializable):
         metadata: Optional[Dict[str, Any]] = None,
         config: Optional[AddConfig] = None,
         dry_run=False,
+        **kwargs: Dict[str, Any],
     ):
         """
         Adds the data from the given URL to the vector db.
@@ -180,21 +181,6 @@ class EmbedChain(JSONSerializable):
         if data_type:
             try:
                 data_type = DataType(data_type)
-                if data_type == DataType.JSON:
-                    if isinstance(source, str):
-                        if not is_valid_json_string(source):
-                            raise ValueError(
-                                f"Invalid json input: {source}",
-                                "Provide the correct JSON formatted source, \
-                                    refer `https://docs.embedchain.ai/data-sources/json`",
-                            )
-                    elif not isinstance(source, str):
-                        raise ValueError(
-                            "Invaid content input. \
-                            If you want to upload (list, dict, etc.), do \
-                                `json.dump(data, indent=0)` and add the stringified JSON. \
-                                    Check - `https://docs.embedchain.ai/data-sources/json`"
-                        )
             except ValueError:
                 raise ValueError(
                     f"Invalid data_type: '{data_type}'.",
@@ -218,8 +204,9 @@ class EmbedChain(JSONSerializable):
             print(f"Data with hash {source_hash} already exists. Skipping addition.")
             return source_hash
 
-        data_formatter = DataFormatter(data_type, config)
         self.user_asks.append([source, data_type.value, metadata])
+
+        data_formatter = DataFormatter(data_type, config, kwargs)
         documents, metadatas, _ids, new_chunks = self.load_and_embed(
             data_formatter.loader, data_formatter.chunker, source, metadata, source_hash, dry_run
         )
@@ -265,6 +252,7 @@ class EmbedChain(JSONSerializable):
         data_type: Optional[DataType] = None,
         metadata: Optional[Dict[str, Any]] = None,
         config: Optional[AddConfig] = None,
+        **kwargs: Dict[str, Any],
     ):
         """
         Adds the data from the given URL to the vector db.
@@ -290,7 +278,13 @@ class EmbedChain(JSONSerializable):
         logging.warning(
             "The `add_local` method is deprecated and will be removed in future versions. Please use the `add` method for both local and remote files."  # noqa: E501
         )
-        return self.add(source=source, data_type=data_type, metadata=metadata, config=config)
+        return self.add(
+            source=source,
+            data_type=data_type,
+            metadata=metadata,
+            config=config,
+            kwargs=kwargs,
+        )
 
     def _get_existing_doc_id(self, chunker: BaseChunker, src: Any):
         """

+ 11 - 0
embedchain/loaders/json.py

@@ -25,10 +25,21 @@ class JSONLoader(BaseLoader):
 
         return LLHUBJSONLoader()
 
+    @staticmethod
+    def _check_content(content):
+        if not isinstance(content, str):
+            raise ValueError(
+                "Invaid content input. \
+                If you want to upload (list, dict, etc.), do \
+                    `json.dump(data, indent=0)` and add the stringified JSON. \
+                        Check - `https://docs.embedchain.ai/data-sources/json`"
+            )
+
     @staticmethod
     def load_data(content):
         """Load a json file. Each data point is a key value pair."""
 
+        JSONLoader._check_content(content)
         loader = JSONLoader._get_llama_hub_loader()
 
         data = []

+ 73 - 0
embedchain/loaders/postgres.py

@@ -0,0 +1,73 @@
+import hashlib
+import logging
+from typing import Any, Dict, Optional
+
+from embedchain.loaders.base_loader import BaseLoader
+
+
+class PostgresLoader(BaseLoader):
+    def __init__(self, config: Optional[Dict[str, Any]] = None):
+        super().__init__()
+        if not config:
+            raise ValueError(f"Must provide the valid config. Received: {config}")
+
+        self.connection = None
+        self.cursor = None
+        self._setup_loader(config=config)
+
+    def _setup_loader(self, config: Dict[str, Any]):
+        try:
+            import psycopg
+        except ImportError as e:
+            raise ImportError(
+                "Unable to import required packages. \
+                    Run `pip install --upgrade 'embedchain[postgres]'`"
+            ) from e
+
+        config_info = ""
+        if "url" in config:
+            config_info = config.get("url")
+        else:
+            conn_params = []
+            for key, value in config.items():
+                conn_params.append(f"{key}={value}")
+            config_info = " ".join(conn_params)
+
+        logging.info(f"Connecting to postrgres sql: {config_info}")
+        self.connection = psycopg.connect(conninfo=config_info)
+        self.cursor = self.connection.cursor()
+
+    def _check_query(self, query):
+        if not isinstance(query, str):
+            raise ValueError(
+                f"Invalid postgres query: {query}",
+                "Provide the valid source to add from postgres, \
+                    make sure you are following `https://docs.embedchain.ai/data-sources/postgres`",
+            )
+
+    def load_data(self, query):
+        self._check_query(query)
+        try:
+            data = []
+            data_content = []
+            self.cursor.execute(query)
+            results = self.cursor.fetchall()
+            for result in results:
+                doc_content = str(result)
+                data.append({"content": doc_content, "meta_data": {"url": f"postgres_query-({query})"}})
+                data_content.append(doc_content)
+            doc_id = hashlib.sha256((query + ", ".join(data_content)).encode()).hexdigest()
+            return {
+                "doc_id": doc_id,
+                "data": data,
+            }
+        except Exception as e:
+            raise ValueError(f"Failed to load data using query={query} with: {e}")
+
+    def close_connection(self):
+        if self.cursor:
+            self.cursor.close()
+            self.cursor = None
+        if self.connection:
+            self.connection.close()
+            self.connection = None

+ 2 - 0
embedchain/models/data_type.py

@@ -29,6 +29,7 @@ class IndirectDataType(Enum):
     JSON = "json"
     OPENAPI = "openapi"
     GMAIL = "gmail"
+    POSTGRES = "postgres"
 
 
 class SpecialDataType(Enum):
@@ -57,3 +58,4 @@ class DataType(Enum):
     JSON = IndirectDataType.JSON.value
     OPENAPI = IndirectDataType.OPENAPI.value
     GMAIL = IndirectDataType.GMAIL.value
+    POSTGRES = IndirectDataType.POSTGRES.value

+ 4 - 0
pyproject.toml

@@ -130,6 +130,9 @@ pymilvus = { version = "2.3.1", optional = true }
 google-cloud-aiplatform = { version = "^1.26.1", optional = true }
 replicate = { version = "^0.15.4", optional = true }
 schema = "^0.7.5"
+psycopg = { version = "^3.1.12", optional = true }
+psycopg-binary = { version = "^3.1.12", optional = true }
+psycopg-pool = { version = "^3.1.8", optional = true }
 
 [tool.poetry.group.dev.dependencies]
 black = "^23.3.0"
@@ -184,6 +187,7 @@ gmail = [
     "google-api-core",
 ]
 json = ["llama-hub"]
+postgres = ["psycopg", "psycopg-binary", "psycopg-pool"]
 
 [tool.poetry.group.docs.dependencies]
 

+ 2 - 0
tests/chunkers/test_chunkers.py

@@ -6,6 +6,7 @@ from embedchain.chunkers.mdx import MdxChunker
 from embedchain.chunkers.notion import NotionChunker
 from embedchain.chunkers.openapi import OpenAPIChunker
 from embedchain.chunkers.pdf_file import PdfFileChunker
+from embedchain.chunkers.postgres import PostgresChunker
 from embedchain.chunkers.qna_pair import QnaPairChunker
 from embedchain.chunkers.sitemap import SitemapChunker
 from embedchain.chunkers.table import TableChunker
@@ -33,6 +34,7 @@ chunker_common_config = {
     JSONChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
     OpenAPIChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
     GmailChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
+    PostgresChunker: {"chunk_size": 1000, "chunk_overlap": 0, "length_function": len},
 }
 
 

+ 1 - 1
tests/embedchain/test_embedchain.py

@@ -63,5 +63,5 @@ def test_add_after_reset(app_instance, mocker):
 def test_add_with_incorrect_content(app_instance, mocker):
     content = [{"foo": "bar"}]
 
-    with pytest.raises(ValueError):
+    with pytest.raises(TypeError):
         app_instance.add(content, data_type="json")

+ 60 - 0
tests/loaders/test_postgres.py

@@ -0,0 +1,60 @@
+from unittest.mock import MagicMock
+
+import psycopg
+import pytest
+
+from embedchain.loaders.postgres import PostgresLoader
+
+
+@pytest.fixture
+def postgres_loader(mocker):
+    with mocker.patch.object(psycopg, "connect"):
+        config = {"url": "postgres://user:password@localhost:5432/database"}
+        loader = PostgresLoader(config=config)
+        yield loader
+
+
+def test_postgres_loader_initialization(postgres_loader):
+    assert postgres_loader.connection is not None
+    assert postgres_loader.cursor is not None
+
+
+def test_postgres_loader_invalid_config():
+    with pytest.raises(ValueError, match="Must provide the valid config. Received: None"):
+        PostgresLoader(config=None)
+
+
+def test_load_data(postgres_loader, monkeypatch):
+    mock_cursor = MagicMock()
+    monkeypatch.setattr(postgres_loader, "cursor", mock_cursor)
+
+    query = "SELECT * FROM table"
+    mock_cursor.fetchall.return_value = [(1, "data1"), (2, "data2")]
+
+    result = postgres_loader.load_data(query)
+
+    assert "doc_id" in result
+    assert "data" in result
+    assert len(result["data"]) == 2
+    assert result["data"][0]["meta_data"]["url"] == f"postgres_query-({query})"
+    assert result["data"][1]["meta_data"]["url"] == f"postgres_query-({query})"
+    assert mock_cursor.execute.called_with(query)
+
+
+def test_load_data_exception(postgres_loader, monkeypatch):
+    mock_cursor = MagicMock()
+    monkeypatch.setattr(postgres_loader, "cursor", mock_cursor)
+
+    _ = "SELECT * FROM table"
+    mock_cursor.execute.side_effect = Exception("Mocked exception")
+
+    with pytest.raises(
+        ValueError, match=r"Failed to load data using query=SELECT \* FROM table with: Mocked exception"
+    ):
+        postgres_loader.load_data("SELECT * FROM table")
+
+
+def test_close_connection(postgres_loader):
+    postgres_loader.close_connection()
+    assert postgres_loader.cursor is None
+    assert postgres_loader.connection is None