database.py 3.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
import os
from typing import Any, Callable, Optional

from alembic import command
from alembic.config import Config
from sqlalchemy import create_engine
from sqlalchemy.engine.base import Engine
from sqlalchemy.orm import Session as SQLAlchemySession
from sqlalchemy.orm import scoped_session, sessionmaker

from .models import Base
  9. class DatabaseManager:
  10. def __init__(self, echo: bool = False):
  11. self.database_uri = os.environ.get("EMBEDCHAIN_DB_URI")
  12. self.echo = echo
  13. self.engine: Engine = None
  14. self._session_factory = None
  15. def setup_engine(self) -> None:
  16. """Initializes the database engine and session factory."""
  17. if not self.database_uri:
  18. raise RuntimeError("Database URI is not set. Set the EMBEDCHAIN_DB_URI environment variable.")
  19. connect_args = {}
  20. if self.database_uri.startswith("sqlite"):
  21. connect_args["check_same_thread"] = False
  22. self.engine = create_engine(self.database_uri, echo=self.echo, connect_args=connect_args)
  23. self._session_factory = scoped_session(sessionmaker(bind=self.engine))
  24. Base.metadata.bind = self.engine
  25. def init_db(self) -> None:
  26. """Creates all tables defined in the Base metadata."""
  27. if not self.engine:
  28. raise RuntimeError("Database engine is not initialized. Call setup_engine() first.")
  29. Base.metadata.create_all(self.engine)
  30. def get_session(self) -> SQLAlchemySession:
  31. """Provides a session for database operations."""
  32. if not self._session_factory:
  33. raise RuntimeError("Session factory is not initialized. Call setup_engine() first.")
  34. return self._session_factory()
  35. def close_session(self) -> None:
  36. """Closes the current session."""
  37. if self._session_factory:
  38. self._session_factory.remove()
  39. def execute_transaction(self, transaction_block):
  40. """Executes a block of code within a database transaction."""
  41. session = self.get_session()
  42. try:
  43. transaction_block(session)
  44. session.commit()
  45. except Exception as e:
  46. session.rollback()
  47. raise e
  48. finally:
  49. self.close_session()
# Shared module-level manager used by the convenience functions below.
# It is constructed at import time, reading EMBEDCHAIN_DB_URI then, but no
# engine exists until setup_engine() is called.
database_manager = DatabaseManager()
  52. # Convenience functions for backward compatibility and ease of use
  53. def setup_engine(database_uri: str, echo: bool = False) -> None:
  54. database_manager.database_uri = database_uri
  55. database_manager.echo = echo
  56. database_manager.setup_engine()
  57. def alembic_upgrade() -> None:
  58. """Upgrades the database to the latest version."""
  59. alembic_config_path = os.path.join(os.path.dirname(__file__), "..", "..", "alembic.ini")
  60. alembic_cfg = Config(alembic_config_path)
  61. command.upgrade(alembic_cfg, "head")
  62. def init_db() -> None:
  63. alembic_upgrade()
  64. def get_session() -> SQLAlchemySession:
  65. return database_manager.get_session()
  66. def execute_transaction(transaction_block):
  67. database_manager.execute_transaction(transaction_block)