|
@@ -0,0 +1,88 @@
|
|
|
+import os
|
|
|
+
|
|
|
+from alembic import command
|
|
|
+from alembic.config import Config
|
|
|
+from sqlalchemy import create_engine
|
|
|
+from sqlalchemy.engine.base import Engine
|
|
|
+from sqlalchemy.orm import Session as SQLAlchemySession
|
|
|
+from sqlalchemy.orm import scoped_session, sessionmaker
|
|
|
+
|
|
|
+from .models import Base
|
|
|
+
|
|
|
+
|
|
|
+class DatabaseManager:
|
|
|
+ def __init__(self, echo: bool = False):
|
|
|
+ self.database_uri = os.environ.get("EMBEDCHAIN_DB_URI")
|
|
|
+ self.echo = echo
|
|
|
+ self.engine: Engine = None
|
|
|
+ self._session_factory = None
|
|
|
+
|
|
|
+ def setup_engine(self) -> None:
|
|
|
+ """Initializes the database engine and session factory."""
|
|
|
+ if not self.database_uri:
|
|
|
+ raise RuntimeError("Database URI is not set. Set the EMBEDCHAIN_DB_URI environment variable.")
|
|
|
+ connect_args = {}
|
|
|
+ if self.database_uri.startswith("sqlite"):
|
|
|
+ connect_args["check_same_thread"] = False
|
|
|
+ self.engine = create_engine(self.database_uri, echo=self.echo, connect_args=connect_args)
|
|
|
+ self._session_factory = scoped_session(sessionmaker(bind=self.engine))
|
|
|
+ Base.metadata.bind = self.engine
|
|
|
+
|
|
|
+ def init_db(self) -> None:
|
|
|
+ """Creates all tables defined in the Base metadata."""
|
|
|
+ if not self.engine:
|
|
|
+ raise RuntimeError("Database engine is not initialized. Call setup_engine() first.")
|
|
|
+ Base.metadata.create_all(self.engine)
|
|
|
+
|
|
|
+ def get_session(self) -> SQLAlchemySession:
|
|
|
+ """Provides a session for database operations."""
|
|
|
+ if not self._session_factory:
|
|
|
+ raise RuntimeError("Session factory is not initialized. Call setup_engine() first.")
|
|
|
+ return self._session_factory()
|
|
|
+
|
|
|
+ def close_session(self) -> None:
|
|
|
+ """Closes the current session."""
|
|
|
+ if self._session_factory:
|
|
|
+ self._session_factory.remove()
|
|
|
+
|
|
|
+ def execute_transaction(self, transaction_block):
|
|
|
+ """Executes a block of code within a database transaction."""
|
|
|
+ session = self.get_session()
|
|
|
+ try:
|
|
|
+ transaction_block(session)
|
|
|
+ session.commit()
|
|
|
+ except Exception as e:
|
|
|
+ session.rollback()
|
|
|
+ raise e
|
|
|
+ finally:
|
|
|
+ self.close_session()
|
|
|
+
|
|
|
+
|
|
|
+# Singleton pattern to use throughout the application
|
|
|
+database_manager = DatabaseManager()
|
|
|
+
|
|
|
+
|
|
|
+# Convenience functions for backward compatibility and ease of use
|
|
|
+def setup_engine(database_uri: str, echo: bool = False) -> None:
|
|
|
+ database_manager.database_uri = database_uri
|
|
|
+ database_manager.echo = echo
|
|
|
+ database_manager.setup_engine()
|
|
|
+
|
|
|
+
|
|
|
+def alembic_upgrade() -> None:
|
|
|
+ """Upgrades the database to the latest version."""
|
|
|
+ alembic_config_path = os.path.join(os.path.dirname(__file__), "..", "..", "alembic.ini")
|
|
|
+ alembic_cfg = Config(alembic_config_path)
|
|
|
+ command.upgrade(alembic_cfg, "head")
|
|
|
+
|
|
|
+
|
|
|
+def init_db() -> None:
|
|
|
+ alembic_upgrade()
|
|
|
+
|
|
|
+
|
|
|
+def get_session() -> SQLAlchemySession:
|
|
|
+ return database_manager.get_session()
|
|
|
+
|
|
|
+
|
|
|
+def execute_transaction(transaction_block):
|
|
|
+ database_manager.execute_transaction(transaction_block)
|