| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556 |
- """
- 数据库连接管理
- """
from typing import AsyncGenerator

from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import sessionmaker

from app.core.config import settings
# --- Async database access (used by FastAPI) --------------------------------
#
# Connection pool: up to 10 persistent connections plus up to 20 overflow
# connections under load. pool_pre_ping tests a pooled connection before it
# is handed out, so stale connections are replaced instead of failing a query.
# SQL echo logging follows the application's DEBUG setting.
async_engine = create_async_engine(
    settings.DATABASE_URL,
    pool_size=10,
    max_overflow=20,
    pool_pre_ping=True,
    echo=settings.DEBUG,
)

# Factory producing AsyncSession objects bound to async_engine.
# expire_on_commit=False keeps already-loaded attributes readable after a
# commit, avoiding an implicit refresh (which would need awaited I/O).
AsyncSessionLocal = async_sessionmaker(
    bind=async_engine,
    expire_on_commit=False,
    class_=AsyncSession,
)
# --- Synchronous database access (used for Alembic migrations) --------------
#
# Uses the separate synchronous URL from settings; echo follows DEBUG and
# pool_pre_ping validates pooled connections before use.
sync_engine = create_engine(
    settings.DATABASE_URL_SYNC,
    pool_pre_ping=True,
    echo=settings.DEBUG,
)

# Factory producing synchronous Session objects bound to sync_engine.
# autoflush=False: pending changes are not flushed automatically before
# queries; callers flush/commit explicitly.
# NOTE(review): autocommit=False is presumably a no-op on SQLAlchemy 2.x
# (required here by async_sessionmaker) — confirm before removing it.
SessionLocal = sessionmaker(
    bind=sync_engine,
    autoflush=False,
    autocommit=False,
)

# Declarative base class that ORM models inherit from.
Base = declarative_base()
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield a database session for dependency injection.

    A new session is opened from ``AsyncSessionLocal`` and yielded to the
    caller. When the generator is resumed normally, the session is
    committed. If an exception is thrown into the generator at the
    ``yield``, or if the commit itself fails, the session is rolled back
    and the exception is re-raised.

    The ``async with`` block closes the session on exit, so no explicit
    ``close()`` call is needed.

    Yields:
        The ``AsyncSession`` for the duration of the caller's work.
    """
    async with AsyncSessionLocal() as session:
        try:
            yield session
            # Reached only when the consumer finished without raising.
            await session.commit()
        except Exception:
            # Undo partial work, then let the error propagate to the caller.
            await session.rollback()
            raise
|