12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788 |
- import os
- from alembic import command
- from alembic.config import Config
- from sqlalchemy import create_engine
- from sqlalchemy.engine.base import Engine
- from sqlalchemy.orm import Session as SQLAlchemySession
- from sqlalchemy.orm import scoped_session, sessionmaker
- from .models import Base
class DatabaseManager:
    """Own a SQLAlchemy engine and a scoped session registry for it.

    The database URI is read from the ``EMBEDCHAIN_DB_URI`` environment
    variable at construction time. It may be overwritten before calling
    :meth:`setup_engine`. Nothing connects to the database until
    :meth:`setup_engine` runs.
    """

    def __init__(self, echo: bool = False):
        """Store configuration only; the engine is created by setup_engine().

        Args:
            echo: Forwarded to ``create_engine`` as its ``echo`` flag.
        """
        self.database_uri = os.environ.get("EMBEDCHAIN_DB_URI")
        self.echo = echo
        # Both remain None until setup_engine() succeeds.
        self.engine: Engine = None
        self._session_factory = None

    def setup_engine(self) -> None:
        """Create the engine and scoped session registry for ``database_uri``.

        May be called again after changing ``database_uri``/``echo``. The new
        engine is built first. Only then are the previous session registry
        and engine released, so the old connection pool is not leaked.

        Raises:
            RuntimeError: If ``database_uri`` is not set.
        """
        if not self.database_uri:
            raise RuntimeError("Database URI is not set. Set the EMBEDCHAIN_DB_URI environment variable.")
        connect_args = {}
        if self.database_uri.startswith("sqlite"):
            # SQLite's driver rejects connections used outside the thread that
            # created them by default; pooled connections may cross threads.
            connect_args["check_same_thread"] = False
        new_engine = create_engine(self.database_uri, echo=self.echo, connect_args=connect_args)

        previous_engine = self.engine
        previous_factory = self._session_factory

        self.engine = new_engine
        self._session_factory = scoped_session(sessionmaker(bind=new_engine))
        # NOTE(review): MetaData.bind is a legacy SQLAlchemy feature (removed in
        # 2.0); kept in case other code still relies on bound metadata.
        Base.metadata.bind = new_engine

        # Release resources from a previous setup only after the new engine
        # exists, so a failing create_engine() leaves the old setup intact.
        if previous_factory is not None:
            previous_factory.remove()
        if previous_engine is not None:
            previous_engine.dispose()

    def init_db(self) -> None:
        """Create every table registered on ``Base.metadata`` that is missing.

        Raises:
            RuntimeError: If :meth:`setup_engine` has not been called.
        """
        if self.engine is None:
            raise RuntimeError("Database engine is not initialized. Call setup_engine() first.")
        Base.metadata.create_all(self.engine)

    def get_session(self) -> SQLAlchemySession:
        """Return the current scope's session from the scoped registry.

        Repeated calls in the same scope return the same session until
        :meth:`close_session` removes it.

        Raises:
            RuntimeError: If :meth:`setup_engine` has not been called.
        """
        if self._session_factory is None:
            raise RuntimeError("Session factory is not initialized. Call setup_engine() first.")
        return self._session_factory()

    def close_session(self) -> None:
        """Close and discard the current scope's session; no-op before setup."""
        if self._session_factory is not None:
            self._session_factory.remove()

    def execute_transaction(self, transaction_block):
        """Run ``transaction_block(session)`` inside one transaction.

        Commits on success. On any ``Exception`` it rolls back and re-raises.
        It always removes the scoped session afterwards.

        Args:
            transaction_block: Callable taking the session as its only argument.

        Returns:
            Whatever ``transaction_block`` returns.
        """
        session = self.get_session()
        try:
            result = transaction_block(session)
            session.commit()
        except Exception:
            session.rollback()
            # Bare ``raise`` re-raises the active exception unchanged.
            raise
        finally:
            # Removing the scoped session closes it. Because the registry hands
            # out one session per scope, this also closes a session the caller
            # obtained from get_session() in the same scope.
            self.close_session()
        return result
# Shared module-level instance; the wrapper functions in this module delegate to it.
database_manager: DatabaseManager = DatabaseManager()
# Module-level wrappers around ``database_manager``, kept for backward
# compatibility and convenience.
def setup_engine(database_uri: str, echo: bool = False) -> None:
    """Point the shared ``database_manager`` at ``database_uri`` and build its engine.

    Args:
        database_uri: SQLAlchemy database URL. It replaces whatever URI the
            manager read from ``EMBEDCHAIN_DB_URI``.
        echo: Whether the engine should log the SQL it emits.
    """
    manager = database_manager
    manager.database_uri, manager.echo = database_uri, echo
    manager.setup_engine()
def alembic_upgrade() -> None:
    """Apply all pending Alembic migrations, i.e. upgrade to ``head``.

    The Alembic configuration is read from the ``alembic.ini`` located two
    directory levels above this module's directory.
    """
    ini_path = os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, "alembic.ini")
    command.upgrade(Config(ini_path), "head")
def init_db() -> None:
    """Bring the database schema up to date by running Alembic migrations.

    This delegates to :func:`alembic_upgrade`. It does not call
    ``DatabaseManager.init_db``, which uses ``create_all`` instead.
    """
    alembic_upgrade()
def get_session() -> SQLAlchemySession:
    """Return the current session from the shared ``database_manager``.

    Raises:
        RuntimeError: Propagated from the manager if ``setup_engine`` has not run.
    """
    manager = database_manager
    return manager.get_session()
def execute_transaction(transaction_block):
    """Run ``transaction_block(session)`` in a transaction on the shared manager.

    Args:
        transaction_block: Callable taking the session as its only argument.

    Returns:
        Whatever ``DatabaseManager.execute_transaction`` returns. The value is
        forwarded instead of being silently dropped.
    """
    return database_manager.execute_transaction(transaction_block)
|