database.py 1.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. """
  2. 数据库连接管理
  3. """
from typing import AsyncGenerator

from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import sessionmaker

from app.core.config import settings
  9. # 异步引擎(用于FastAPI)
  10. async_engine = create_async_engine(
  11. settings.DATABASE_URL,
  12. echo=settings.DEBUG,
  13. pool_pre_ping=True,
  14. pool_size=10,
  15. max_overflow=20,
  16. )
  17. # 异步会话工厂
  18. AsyncSessionLocal = async_sessionmaker(
  19. async_engine,
  20. class_=AsyncSession,
  21. expire_on_commit=False,
  22. )
  23. # 同步引擎(用于Alembic迁移)
  24. sync_engine = create_engine(
  25. settings.DATABASE_URL_SYNC,
  26. echo=settings.DEBUG,
  27. pool_pre_ping=True,
  28. )
  29. # 同步会话工厂
  30. SessionLocal = sessionmaker(
  31. autocommit=False,
  32. autoflush=False,
  33. bind=sync_engine,
  34. )
  35. # 基础模型类
  36. Base = declarative_base()
  37. async def get_db() -> AsyncSession:
  38. """
  39. 获取数据库会话(依赖注入)
  40. """
  41. async with AsyncSessionLocal() as session:
  42. try:
  43. yield session
  44. await session.commit()
  45. except Exception:
  46. await session.rollback()
  47. raise
  48. finally:
  49. await session.close()