瀏覽代碼

Fix CI/CD (#1492)

Deshraj Yadav 1 年之前
父節點
當前提交
fb5a3bfd95

+ 11 - 5
.github/workflows/cd.yml

@@ -2,14 +2,13 @@ name: Publish Python 🐍 distributions 📦 to PyPI and TestPyPI
 
 on:
   release:
-    types: [published] # This will trigger the workflow when you create a new release
+    types: [published]
 
 jobs:
   build-n-publish:
     name: Build and publish Python 🐍 distributions 📦 to PyPI and TestPyPI
     runs-on: ubuntu-latest
     permissions:
-      # IMPORTANT: this permission is mandatory for trusted publishing
       id-token: write
     steps:
       - uses: actions/checkout@v2
@@ -25,16 +24,23 @@ jobs:
           echo "$HOME/.local/bin" >> $GITHUB_PATH
 
       - name: Install dependencies
-        run: poetry install
+        run: |
+          cd embedchain
+          poetry install
 
       - name: Build a binary wheel and a source tarball
-        run: poetry build
+        run: |
+          cd embedchain
+          poetry build
 
       - name: Publish distribution 📦 to Test PyPI
         uses: pypa/gh-action-pypi-publish@release/v1
         with:
           repository_url: https://test.pypi.org/legacy/
+          packages_dir: embedchain/dist/
 
       - name: Publish distribution 📦 to PyPI
         if: startsWith(github.ref, 'refs/tags')
-        uses: pypa/gh-action-pypi-publish@release/v1
+        uses: pypa/gh-action-pypi-publish@release/v1
+        with:
+          packages_dir: embedchain/dist/

+ 8 - 8
.github/workflows/ci.yml

@@ -5,13 +5,13 @@ on:
     branches: [main]
     paths:
       - 'embedchain/**'
-      - 'tests/**'
-      - 'examples/**'
+      - 'embedchain/tests/**'
+      - 'embedchain/examples/**'
   pull_request:
     paths:
-      - 'embedchain/**'
-      - 'tests/**'
-      - 'examples/**'
+      - 'embedchain/embedchain/**'
+      - 'embedchain/tests/**'
+      - 'embedchain/examples/**'
 
 jobs:
   build:
@@ -39,12 +39,12 @@ jobs:
           path: .venv
           key: venv-${{ runner.os }}-${{ hashFiles('**/poetry.lock') }}
       - name: Install dependencies
-        run: make install_all
+        run: cd embedchain && make install_all
         if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
       - name: Lint with ruff
-        run: make lint
+        run: cd embedchain && make lint
       - name: Run tests and generate coverage report
-        run: make coverage
+        run: cd embedchain && make coverage
       - name: Upload coverage reports to Codecov
         uses: codecov/codecov-action@v3
         with:

+ 2 - 2
.gitignore

@@ -165,7 +165,7 @@ cython_debug/
 # Database
 db
 test-db
-!embedchain/core/db/
+!embedchain/embedchain/core/db/
 
 .vscode
 .idea/
@@ -183,4 +183,4 @@ notebooks/*.yaml
 # local directories for testing
 eval/
 qdrant_storage/
-.crossnote
+.crossnote

+ 1 - 1
embedchain/embedchain/config/llm/base.py

@@ -234,7 +234,7 @@ class BaseLlmConfig(BaseConfig):
         self.api_version = api_version
 
         if token_usage:
-            f = open("config/model_prices_and_context_window.json")
+            f = open("embedchain/config/model_prices_and_context_window.json")
             self.model_pricing_map = json.load(f)
 
         if isinstance(prompt, str):

+ 0 - 0
embedchain/embedchain/core/db/__init__.py


+ 88 - 0
embedchain/embedchain/core/db/database.py

@@ -0,0 +1,88 @@
+import os
+
+from alembic import command
+from alembic.config import Config
+from sqlalchemy import create_engine
+from sqlalchemy.engine.base import Engine
+from sqlalchemy.orm import Session as SQLAlchemySession
+from sqlalchemy.orm import scoped_session, sessionmaker
+
+from .models import Base
+
+
+class DatabaseManager:
+    def __init__(self, echo: bool = False):
+        self.database_uri = os.environ.get("EMBEDCHAIN_DB_URI")
+        self.echo = echo
+        self.engine: Engine = None
+        self._session_factory = None
+
+    def setup_engine(self) -> None:
+        """Initializes the database engine and session factory."""
+        if not self.database_uri:
+            raise RuntimeError("Database URI is not set. Set the EMBEDCHAIN_DB_URI environment variable.")
+        connect_args = {}
+        if self.database_uri.startswith("sqlite"):
+            connect_args["check_same_thread"] = False
+        self.engine = create_engine(self.database_uri, echo=self.echo, connect_args=connect_args)
+        self._session_factory = scoped_session(sessionmaker(bind=self.engine))
+        Base.metadata.bind = self.engine
+
+    def init_db(self) -> None:
+        """Creates all tables defined in the Base metadata."""
+        if not self.engine:
+            raise RuntimeError("Database engine is not initialized. Call setup_engine() first.")
+        Base.metadata.create_all(self.engine)
+
+    def get_session(self) -> SQLAlchemySession:
+        """Provides a session for database operations."""
+        if not self._session_factory:
+            raise RuntimeError("Session factory is not initialized. Call setup_engine() first.")
+        return self._session_factory()
+
+    def close_session(self) -> None:
+        """Closes the current session."""
+        if self._session_factory:
+            self._session_factory.remove()
+
+    def execute_transaction(self, transaction_block):
+        """Runs transaction_block(session); commits on success, rolls back and re-raises on failure."""
+        session = self.get_session()
+        try:
+            transaction_block(session)
+            session.commit()
+        except Exception as e:
+            session.rollback()
+            raise e
+        finally:
+            self.close_session()
+
+
+# Shared module-level DatabaseManager instance, used as a singleton throughout the application
+database_manager = DatabaseManager()
+
+
+# Convenience functions for backward compatibility and ease of use
+def setup_engine(database_uri: str, echo: bool = False) -> None:
+    database_manager.database_uri = database_uri
+    database_manager.echo = echo
+    database_manager.setup_engine()
+
+
+def alembic_upgrade() -> None:
+    """Runs Alembic migrations up to the `head` revision, using the alembic.ini two directories above this module."""
+    alembic_config_path = os.path.join(os.path.dirname(__file__), "..", "..", "alembic.ini")
+    alembic_cfg = Config(alembic_config_path)
+    command.upgrade(alembic_cfg, "head")
+
+
+def init_db() -> None:
+    alembic_upgrade()
+
+
+def get_session() -> SQLAlchemySession:
+    return database_manager.get_session()
+
+
+def execute_transaction(transaction_block):
+    database_manager.execute_transaction(transaction_block)

+ 31 - 0
embedchain/embedchain/core/db/models.py

@@ -0,0 +1,31 @@
+import uuid
+
+from sqlalchemy import TIMESTAMP, Column, Integer, String, Text, func
+from sqlalchemy.orm import declarative_base
+
+Base = declarative_base()
+metadata = Base.metadata
+
+
+class DataSource(Base):
+    __tablename__ = "ec_data_sources"
+
+    id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
+    app_id = Column(Text, index=True)
+    hash = Column(Text, index=True)
+    type = Column(Text, index=True)
+    value = Column(Text)
+    meta_data = Column(Text, name="metadata")
+    is_uploaded = Column(Integer, default=0)
+
+
+class ChatHistory(Base):
+    __tablename__ = "ec_chat_history"
+
+    app_id = Column(String, primary_key=True)
+    id = Column(String, primary_key=True)
+    session_id = Column(String, primary_key=True, index=True)
+    question = Column(Text)
+    answer = Column(Text)
+    meta_data = Column(Text, name="metadata")
+    created_at = Column(TIMESTAMP, default=func.current_timestamp(), index=True)

+ 2 - 1
embedchain/embedchain/embedder/gpt4all.py

@@ -12,7 +12,8 @@ class GPT4AllEmbedder(BaseEmbedder):
         from langchain_community.embeddings import GPT4AllEmbeddings as LangchainGPT4AllEmbeddings
 
         model_name = self.config.model or "all-MiniLM-L6-v2-f16.gguf"
-        embeddings = LangchainGPT4AllEmbeddings(model_name=model_name)
+        gpt4all_kwargs = {'allow_download': 'True'}
+        embeddings = LangchainGPT4AllEmbeddings(model_name=model_name, gpt4all_kwargs=gpt4all_kwargs)
         embedding_fn = BaseEmbedder._langchain_default_concept(embeddings)
         self.set_embedding_fn(embedding_fn=embedding_fn)
 

+ 1 - 1
embedchain/pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "embedchain"
-version = "0.1.116"
+version = "0.1.117"
 description = "Simplest open source retrieval (RAG) framework"
 authors = [
     "Taranjeet Singh <taranjeet@embedchain.ai>",

+ 2 - 2
embedchain/tests/test_app.py

@@ -13,8 +13,8 @@ from embedchain.vectordb.chroma import ChromaDB
 
 @pytest.fixture
 def app():
-    os.environ["OPENAI_API_KEY"] = "test_api_key"
-    os.environ["OPENAI_API_BASE"] = "test_api_base"
+    os.environ["OPENAI_API_KEY"] = "test-api-key"
+    os.environ["OPENAI_API_BASE"] = "test-api-base"
     return App()