101 lines
2.2 KiB
Python
101 lines
2.2 KiB
Python
"""
|
||
数据库连接与会话管理
|
||
|
||
使用 SQLAlchemy 2.0 异步模式。
|
||
"""
|
||
|
||
import logging
|
||
from collections.abc import AsyncGenerator
|
||
from pathlib import Path
|
||
|
||
from sqlalchemy.ext.asyncio import (
|
||
AsyncSession,
|
||
async_sessionmaker,
|
||
create_async_engine,
|
||
)
|
||
from sqlalchemy.orm import DeclarativeBase
|
||
|
||
from app.core.config import settings
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class Base(DeclarativeBase):
|
||
"""SQLAlchemy 声明式基类"""
|
||
pass
|
||
|
||
|
||
def _ensure_sqlite_dir() -> None:
|
||
"""
|
||
确保 SQLite 数据库目录存在
|
||
|
||
如果配置使用 SQLite,自动创建数据库文件所在的目录。
|
||
"""
|
||
if settings.database_type != "sqlite":
|
||
return
|
||
|
||
db_path = Path(settings.database_sqlite_path)
|
||
db_dir = db_path.parent
|
||
|
||
if not db_dir.exists():
|
||
logger.info(f"创建数据库目录: {db_dir.absolute()}")
|
||
db_dir.mkdir(parents=True, exist_ok=True)
|
||
|
||
|
||
# 确保 SQLite 目录存在
|
||
_ensure_sqlite_dir()
|
||
|
||
# 创建异步引擎
|
||
engine = create_async_engine(
|
||
settings.database_url,
|
||
echo=settings.database_echo,
|
||
future=True,
|
||
)
|
||
|
||
# 创建异步会话工厂
|
||
async_session_factory = async_sessionmaker(
|
||
engine,
|
||
class_=AsyncSession,
|
||
expire_on_commit=False,
|
||
autocommit=False,
|
||
autoflush=False,
|
||
)
|
||
|
||
|
||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||
"""
|
||
获取数据库会话(依赖注入用)
|
||
|
||
Yields:
|
||
异步数据库会话
|
||
"""
|
||
async with async_session_factory() as session:
|
||
try:
|
||
yield session
|
||
finally:
|
||
await session.close()
|
||
|
||
|
||
async def init_db() -> None:
|
||
"""初始化数据库(创建所有表)"""
|
||
# 导入所有模型以确保它们被注册
|
||
from app.models import ( # noqa: F401
|
||
User,
|
||
UserBalance,
|
||
BalanceTransaction,
|
||
RedeemCode,
|
||
RedeemCodeBatch,
|
||
RedeemCodeUsageLog,
|
||
)
|
||
|
||
logger.info(f"初始化数据库: {settings.database_url}")
|
||
async with engine.begin() as conn:
|
||
await conn.run_sync(Base.metadata.create_all)
|
||
logger.info("数据库初始化完成")
|
||
|
||
|
||
async def close_db() -> None:
|
||
"""关闭数据库连接"""
|
||
await engine.dispose()
|
||
|