Files
SatoNano/app/database.py
2026-01-06 23:49:23 +08:00

101 lines
2.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
数据库连接与会话管理
使用 SQLAlchemy 2.0 异步模式。
"""
import logging
from collections.abc import AsyncGenerator
from pathlib import Path
from sqlalchemy.ext.asyncio import (
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from sqlalchemy.orm import DeclarativeBase
from app.core.config import settings
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class Base(DeclarativeBase):
    """Declarative base class for this application's SQLAlchemy ORM models.

    ``init_db`` calls ``Base.metadata.create_all``, so the tables of models
    deriving from this class are the ones it creates.
    """
def _ensure_sqlite_dir() -> None:
    """
    Create the directory that will hold the SQLite database file, if missing.

    Does nothing unless ``settings.database_type`` is ``"sqlite"``. The parent
    directory of ``settings.database_sqlite_path`` is created (including any
    missing ancestors) and an info message is logged when that happens.
    """
    if settings.database_type != "sqlite":
        return

    target_dir = Path(settings.database_sqlite_path).parent
    if target_dir.exists():
        return

    logger.info(f"创建数据库目录: {target_dir.absolute()}")
    # exist_ok=True tolerates the directory appearing between the check and mkdir.
    target_dir.mkdir(parents=True, exist_ok=True)
# Make sure the SQLite database directory exists before the engine is created.
_ensure_sqlite_dir()

# Module-level async engine built from the configured database URL.
# The legacy ``future=True`` flag is not passed: under SQLAlchemy 2.0 it is a
# backwards-compatibility no-op that must stay at its default of True.
engine = create_async_engine(
    settings.database_url,
    echo=settings.database_echo,
)

# Async session factory bound to ``engine``.
# ``AsyncSession`` is already async_sessionmaker's default class, and
# ``autocommit=False`` is the only value SQLAlchemy 2.0 accepts, so neither is
# passed. The two options below differ from the defaults and stay explicit:
#   - expire_on_commit=False: loaded attributes stay usable after commit().
#   - autoflush=False: pending changes are not flushed automatically before queries.
async_session_factory = async_sessionmaker(
    engine,
    expire_on_commit=False,
    autoflush=False,
)
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """
    Yield a database session (intended for dependency injection).

    Each call opens a new ``AsyncSession`` from ``async_session_factory``.
    The session is closed when the generator finishes, including when it is
    closed early or an exception is thrown into it.

    Yields:
        An asynchronous database session.
    """
    # Leaving the ``async with`` block closes the session, so no explicit
    # close() call is needed; adding one would close the session twice.
    async with async_session_factory() as session:
        yield session
async def init_db() -> None:
    """
    Initialize the database by creating the tables registered on ``Base.metadata``.

    The model classes are imported first so that their tables are registered
    on ``Base.metadata`` before ``create_all`` runs inside a transaction
    opened with ``engine.begin()``.
    """
    # Import every model so its table is registered on Base.metadata.
    from app.models import (  # noqa: F401
        User,
        UserBalance,
        BalanceTransaction,
        RedeemCode,
        RedeemCodeBatch,
        RedeemCodeUsageLog,
    )

    # Log the URL with the password masked. A server database URL
    # (user:password@host) would otherwise leak credentials into the logs.
    safe_url = engine.url.render_as_string(hide_password=True)
    logger.info("初始化数据库: %s", safe_url)
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    logger.info("数据库初始化完成")
async def close_db() -> None:
    """Dispose of the module-level engine and release its connection pool."""
    await engine.dispose()