database.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
import os
from typing import Optional

from alembic import command
from alembic.config import Config
from sqlalchemy import create_engine
from sqlalchemy.engine.base import Engine
from sqlalchemy.orm import Session as SQLAlchemySession
from sqlalchemy.orm import scoped_session, sessionmaker

from .models import Base
  9. class DatabaseManager:
  10. def __init__(self, database_uri: str = "sqlite:///embedchain.db", echo: bool = False):
  11. self.database_uri = database_uri
  12. self.echo = echo
  13. self.engine: Engine = None
  14. self._session_factory = None
  15. def setup_engine(self) -> None:
  16. """Initializes the database engine and session factory."""
  17. self.engine = create_engine(self.database_uri, echo=self.echo, connect_args={"check_same_thread": False})
  18. self._session_factory = scoped_session(sessionmaker(bind=self.engine))
  19. Base.metadata.bind = self.engine
  20. def init_db(self) -> None:
  21. """Creates all tables defined in the Base metadata."""
  22. if not self.engine:
  23. raise RuntimeError("Database engine is not initialized. Call setup_engine() first.")
  24. Base.metadata.create_all(self.engine)
  25. def get_session(self) -> SQLAlchemySession:
  26. """Provides a session for database operations."""
  27. if not self._session_factory:
  28. raise RuntimeError("Session factory is not initialized. Call setup_engine() first.")
  29. return self._session_factory()
  30. def close_session(self) -> None:
  31. """Closes the current session."""
  32. if self._session_factory:
  33. self._session_factory.remove()
  34. def execute_transaction(self, transaction_block):
  35. """Executes a block of code within a database transaction."""
  36. session = self.get_session()
  37. try:
  38. transaction_block(session)
  39. session.commit()
  40. except Exception as e:
  41. session.rollback()
  42. raise e
  43. finally:
  44. self.close_session()
  45. # Singleton pattern to use throughout the application
  46. database_manager = DatabaseManager()
  47. # Convenience functions for backward compatibility and ease of use
  48. def setup_engine(database_uri: str = "sqlite:///embedchain.db", echo: bool = False) -> None:
  49. database_manager.database_uri = database_uri
  50. database_manager.echo = echo
  51. database_manager.setup_engine()
  52. def alembic_upgrade() -> None:
  53. """Upgrades the database to the latest version."""
  54. alembic_config_path = os.path.join(os.path.dirname(__file__), "..", "..", "alembic.ini")
  55. alembic_cfg = Config(alembic_config_path)
  56. command.upgrade(alembic_cfg, "head")
  57. def init_db() -> None:
  58. alembic_upgrade()
  59. def get_session() -> SQLAlchemySession:
  60. return database_manager.get_session()
  61. def execute_transaction(transaction_block):
  62. database_manager.execute_transaction(transaction_block)