1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283 |
- import os
- from alembic import command
- from alembic.config import Config
- from sqlalchemy import create_engine
- from sqlalchemy.engine.base import Engine
- from sqlalchemy.orm import Session as SQLAlchemySession
- from sqlalchemy.orm import scoped_session, sessionmaker
- from .models import Base
class DatabaseManager:
    """Own the SQLAlchemy engine and scoped session factory used by this module.

    Nothing connects until :meth:`setup_engine` is called. Until then,
    :meth:`init_db`, :meth:`get_session` and :meth:`execute_transaction`
    raise ``RuntimeError``.
    """

    def __init__(self, database_uri: str = "sqlite:///embedchain.db", echo: bool = False):
        """Record connection settings without creating an engine.

        Args:
            database_uri: SQLAlchemy database URL.
            echo: Forwarded to ``create_engine``; when true the engine logs the SQL it emits.
        """
        self.database_uri = database_uri
        self.echo = echo
        # Both stay None until setup_engine() runs; the other methods check for that.
        self.engine: Engine = None
        self._session_factory = None

    def setup_engine(self) -> None:
        """Initializes the database engine and session factory.

        Builds the engine from ``database_uri``/``echo``, wraps a ``sessionmaker``
        bound to it in a ``scoped_session`` registry, and records the engine on
        ``Base.metadata.bind``.
        """
        connect_args = {}
        # ``check_same_thread`` is a sqlite3-only connect argument. Other DBAPI
        # drivers (e.g. psycopg2, pymysql) reject unknown connect args, so it
        # must not be passed for non-SQLite URLs.
        if self.database_uri.startswith("sqlite"):
            # Sessions from the scoped registry may run on threads other than
            # the one that opened the connection; sqlite3 forbids that by default.
            connect_args["check_same_thread"] = False
        self.engine = create_engine(self.database_uri, echo=self.echo, connect_args=connect_args)
        self._session_factory = scoped_session(sessionmaker(bind=self.engine))
        # NOTE(review): bound metadata was removed in SQLAlchemy 2.0, so there this
        # is only a plain attribute; kept in case callers read it — confirm.
        Base.metadata.bind = self.engine

    def init_db(self) -> None:
        """Creates all tables defined in the Base metadata.

        This calls ``create_all`` directly and does not run Alembic migrations.

        Raises:
            RuntimeError: If :meth:`setup_engine` has not been called.
        """
        if not self.engine:
            raise RuntimeError("Database engine is not initialized. Call setup_engine() first.")
        Base.metadata.create_all(self.engine)

    def get_session(self) -> SQLAlchemySession:
        """Provides a session for database operations.

        Returns the session held by the ``scoped_session`` registry for the
        current scope, which is thread-local by default. Repeated calls in the
        same scope return the same session until :meth:`close_session`.

        Raises:
            RuntimeError: If :meth:`setup_engine` has not been called.
        """
        if not self._session_factory:
            raise RuntimeError("Session factory is not initialized. Call setup_engine() first.")
        return self._session_factory()

    def close_session(self) -> None:
        """Closes the current session.

        Calls ``scoped_session.remove()``, which closes and discards the current
        scope's session. Does nothing if :meth:`setup_engine` was never called.
        """
        if self._session_factory:
            self._session_factory.remove()

    def execute_transaction(self, transaction_block):
        """Executes a block of code within a database transaction.

        Calls ``transaction_block(session)`` and commits. If the block or the
        commit raises, rolls back and re-raises the original exception. The
        scoped session is removed afterwards in every case.

        NOTE(review): the session is the shared scoped session, so pending work
        from a session the caller obtained via get_session() in the same scope
        is committed or rolled back here as well.

        Args:
            transaction_block: Callable taking the session; its return value is ignored.

        Raises:
            RuntimeError: If :meth:`setup_engine` has not been called.
        """
        session = self.get_session()
        try:
            transaction_block(session)
            session.commit()
        except Exception:
            session.rollback()
            # A bare ``raise`` re-raises the active exception unchanged.
            raise
        finally:
            self.close_session()
# Shared module-level instance used by the convenience functions in this module.
# It is a conventional shared object, not an enforced singleton: nothing
# prevents constructing further DatabaseManager instances.
database_manager: DatabaseManager = DatabaseManager()
# Convenience functions for backward compatibility and ease of use
def setup_engine(database_uri: str = "sqlite:///embedchain.db", echo: bool = False) -> None:
    """Point the shared ``database_manager`` at *database_uri* and build its engine.

    Overwrites the manager's ``database_uri`` and ``echo`` settings, then calls
    its ``setup_engine()`` method.

    Args:
        database_uri: SQLAlchemy database URL to connect to.
        echo: Whether the engine should log the SQL it emits.
    """
    manager = database_manager
    manager.database_uri, manager.echo = database_uri, echo
    manager.setup_engine()
def alembic_upgrade() -> None:
    """Upgrades the database to the latest version.

    Loads ``alembic.ini`` from two directories above this module's directory
    and runs ``alembic upgrade head`` with it.
    """
    # NOTE(review): the database that gets migrated is whatever the Alembic
    # config / env.py resolves, not necessarily database_manager.database_uri.
    # Confirm the two agree when a custom URI is passed to setup_engine().
    module_dir = os.path.dirname(__file__)
    ini_path = os.path.join(module_dir, "..", "..", "alembic.ini")
    command.upgrade(Config(ini_path), "head")
def init_db() -> None:
    """Bring the schema up to date by applying Alembic migrations.

    Delegates to ``alembic_upgrade()``. This does not call
    ``DatabaseManager.init_db`` (which uses ``create_all``), and it does not
    read the shared ``database_manager`` at all.
    """
    alembic_upgrade()
def get_session() -> SQLAlchemySession:
    """Return a session from the shared ``database_manager``.

    Raises:
        RuntimeError: Propagated from the manager when ``setup_engine()`` has not run.
    """
    manager = database_manager
    return manager.get_session()
def execute_transaction(transaction_block):
    """Run ``transaction_block(session)`` in a transaction on the shared manager.

    Thin wrapper over ``database_manager.execute_transaction``. It returns None
    and lets any exception from the manager propagate.
    """
    manager = database_manager
    manager.execute_transaction(transaction_block)
|