""" 数据库连接与会话管理 使用 SQLAlchemy 2.0 异步模式。 """ import logging from collections.abc import AsyncGenerator from pathlib import Path from sqlalchemy.ext.asyncio import ( AsyncSession, async_sessionmaker, create_async_engine, ) from sqlalchemy.orm import DeclarativeBase from app.core.config import settings logger = logging.getLogger(__name__) class Base(DeclarativeBase): """SQLAlchemy 声明式基类""" pass def _ensure_sqlite_dir() -> None: """ 确保 SQLite 数据库目录存在 如果配置使用 SQLite,自动创建数据库文件所在的目录。 """ if settings.database_type != "sqlite": return db_path = Path(settings.database_sqlite_path) db_dir = db_path.parent if not db_dir.exists(): logger.info(f"创建数据库目录: {db_dir.absolute()}") db_dir.mkdir(parents=True, exist_ok=True) # 确保 SQLite 目录存在 _ensure_sqlite_dir() # 创建异步引擎 engine = create_async_engine( settings.database_url, echo=settings.database_echo, future=True, ) # 创建异步会话工厂 async_session_factory = async_sessionmaker( engine, class_=AsyncSession, expire_on_commit=False, autocommit=False, autoflush=False, ) async def get_db() -> AsyncGenerator[AsyncSession, None]: """ 获取数据库会话(依赖注入用) Yields: 异步数据库会话 """ async with async_session_factory() as session: try: yield session finally: await session.close() async def init_db() -> None: """初始化数据库(创建所有表)""" # 导入所有模型以确保它们被注册 from app.models import ( # noqa: F401 User, UserBalance, BalanceTransaction, RedeemCode, RedeemCodeBatch, RedeemCodeUsageLog, ) logger.info(f"初始化数据库: {settings.database_url}") async with engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) logger.info("数据库初始化完成") async def close_db() -> None: """关闭数据库连接""" await engine.dispose()