提供基本前后端骨架
This commit is contained in:
6
app/__init__.py
Normal file
6
app/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""
|
||||
SatoNano - 现代化用户认证系统
|
||||
"""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
|
||||
2
app/api/__init__.py
Normal file
2
app/api/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""API 模块"""
|
||||
|
||||
108
app/api/deps.py
Normal file
108
app/api/deps.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""
|
||||
API 依赖注入
|
||||
|
||||
定义 FastAPI 依赖项。
|
||||
"""
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.exceptions import (
|
||||
TokenError,
|
||||
UserDisabledError,
|
||||
)
|
||||
from app.database import get_db
|
||||
from app.models.user import User
|
||||
from app.services.auth import AuthService
|
||||
|
||||
# HTTP Bearer 认证方案
|
||||
security = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
async def get_auth_service(
|
||||
session: Annotated[AsyncSession, Depends(get_db)],
|
||||
) -> AuthService:
|
||||
"""获取认证服务实例"""
|
||||
return AuthService(session)
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
credentials: Annotated[
|
||||
HTTPAuthorizationCredentials | None,
|
||||
Depends(security),
|
||||
],
|
||||
auth_service: Annotated[AuthService, Depends(get_auth_service)],
|
||||
) -> User:
|
||||
"""
|
||||
获取当前认证用户
|
||||
|
||||
从请求头的 Bearer Token 中解析用户信息。
|
||||
|
||||
Raises:
|
||||
HTTPException 401: 未认证
|
||||
HTTPException 403: 用户被禁用
|
||||
"""
|
||||
if not credentials:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="未提供认证信息",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
try:
|
||||
user = await auth_service.get_current_user(credentials.credentials)
|
||||
return user
|
||||
except TokenError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=e.message,
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
except UserDisabledError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=e.message,
|
||||
)
|
||||
|
||||
|
||||
async def get_current_active_user(
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
) -> User:
|
||||
"""
|
||||
获取当前活跃用户
|
||||
|
||||
确保用户处于激活状态。
|
||||
"""
|
||||
if not current_user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="账户已被禁用",
|
||||
)
|
||||
return current_user
|
||||
|
||||
|
||||
async def get_current_superuser(
|
||||
current_user: Annotated[User, Depends(get_current_active_user)],
|
||||
) -> User:
|
||||
"""
|
||||
获取当前超级管理员
|
||||
|
||||
确保用户具有超级管理员权限。
|
||||
"""
|
||||
if not current_user.is_superuser:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="需要管理员权限",
|
||||
)
|
||||
return current_user
|
||||
|
||||
|
||||
# 类型别名,简化依赖注入声明
|
||||
DbSession = Annotated[AsyncSession, Depends(get_db)]
|
||||
CurrentUser = Annotated[User, Depends(get_current_user)]
|
||||
ActiveUser = Annotated[User, Depends(get_current_active_user)]
|
||||
SuperUser = Annotated[User, Depends(get_current_superuser)]
|
||||
|
||||
2
app/api/v1/__init__.py
Normal file
2
app/api/v1/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""API v1 模块"""
|
||||
|
||||
2
app/api/v1/endpoints/__init__.py
Normal file
2
app/api/v1/endpoints/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""API v1 端点"""
|
||||
|
||||
6
app/api/v1/endpoints/admin/__init__.py
Normal file
6
app/api/v1/endpoints/admin/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""管理员 API 端点"""
|
||||
|
||||
from app.api.v1.endpoints.admin import redeem_codes
|
||||
|
||||
__all__ = ["redeem_codes"]
|
||||
|
||||
613
app/api/v1/endpoints/admin/redeem_codes.py
Normal file
613
app/api/v1/endpoints/admin/redeem_codes.py
Normal file
@@ -0,0 +1,613 @@
|
||||
"""
|
||||
管理员 - 兑换码管理 API
|
||||
|
||||
包括批量生成、导入导出、查询使用日志等接口。
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query, status
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app.api.deps import SuperUser, DbSession
|
||||
from app.core.exceptions import ResourceNotFoundError, ValidationError
|
||||
from app.schemas.base import APIResponse, PaginatedResponse
|
||||
from app.schemas.balance import (
|
||||
AdminAdjustmentRequest,
|
||||
AdminBalanceResponse,
|
||||
format_display,
|
||||
)
|
||||
from app.schemas.redeem_code import (
|
||||
BatchCreateRequest,
|
||||
BatchResponse,
|
||||
BatchDetailResponse,
|
||||
BatchListResponse,
|
||||
RedeemCodeDetailResponse,
|
||||
RedeemCodeListResponse,
|
||||
BulkImportRequest,
|
||||
BulkImportResponse,
|
||||
ExportResponse,
|
||||
ExportCodeItem,
|
||||
UsageLogResponse,
|
||||
UsageLogListResponse,
|
||||
)
|
||||
from app.services.balance import BalanceService, InsufficientBalanceError
|
||||
from app.services.redeem_code import RedeemCodeService
|
||||
from app.models.redeem_code import RedeemCodeStatus
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 批次管理
|
||||
# ============================================================
|
||||
|
||||
@router.post(
|
||||
"/batches",
|
||||
response_model=APIResponse[BatchResponse],
|
||||
summary="创建兑换码批次",
|
||||
description="批量生成指定数量的兑换码",
|
||||
)
|
||||
async def create_batch(
|
||||
request: BatchCreateRequest,
|
||||
current_user: SuperUser,
|
||||
session: DbSession,
|
||||
) -> APIResponse[BatchResponse]:
|
||||
"""
|
||||
创建兑换码批次
|
||||
|
||||
批量生成指定数量的兑换码。
|
||||
|
||||
- **name**: 批次名称
|
||||
- **face_value**: 面值(如 10.00)
|
||||
- **count**: 生成数量(最大 10000)
|
||||
- **max_uses**: 每个兑换码最大使用次数
|
||||
- **expires_at**: 过期时间(可选)
|
||||
"""
|
||||
redeem_service = RedeemCodeService(session)
|
||||
|
||||
try:
|
||||
batch = await redeem_service.create_batch(
|
||||
name=request.name,
|
||||
face_value_units=request.face_value_units,
|
||||
count=request.count,
|
||||
created_by=current_user.id,
|
||||
description=request.description,
|
||||
max_uses=request.max_uses,
|
||||
expires_at=request.expires_at,
|
||||
)
|
||||
|
||||
return APIResponse.ok(
|
||||
data=BatchResponse(
|
||||
id=batch.id,
|
||||
name=batch.name,
|
||||
description=batch.description,
|
||||
face_value_units=batch.face_value,
|
||||
total_count=batch.total_count,
|
||||
used_count=batch.used_count,
|
||||
created_by=batch.created_by,
|
||||
created_at=batch.created_at,
|
||||
),
|
||||
message=f"成功创建批次,生成 {request.count} 个兑换码",
|
||||
)
|
||||
|
||||
except ValidationError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=e.message,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/batches",
|
||||
response_model=APIResponse[BatchListResponse],
|
||||
summary="获取批次列表",
|
||||
description="获取所有兑换码批次",
|
||||
)
|
||||
async def get_batches(
|
||||
current_user: SuperUser,
|
||||
session: DbSession,
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=100),
|
||||
) -> APIResponse[BatchListResponse]:
|
||||
"""获取批次列表"""
|
||||
redeem_service = RedeemCodeService(session)
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
batches, total = await redeem_service.get_batches(
|
||||
offset=offset,
|
||||
limit=page_size,
|
||||
)
|
||||
|
||||
items = [
|
||||
BatchResponse(
|
||||
id=b.id,
|
||||
name=b.name,
|
||||
description=b.description,
|
||||
face_value_units=b.face_value,
|
||||
total_count=b.total_count,
|
||||
used_count=b.used_count,
|
||||
created_by=b.created_by,
|
||||
created_at=b.created_at,
|
||||
)
|
||||
for b in batches
|
||||
]
|
||||
|
||||
return APIResponse.ok(
|
||||
data=BatchListResponse.create(
|
||||
items=items,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/batches/{batch_id}",
|
||||
response_model=APIResponse[BatchDetailResponse],
|
||||
summary="获取批次详情",
|
||||
description="获取指定批次的详细信息",
|
||||
)
|
||||
async def get_batch_detail(
|
||||
batch_id: str,
|
||||
current_user: SuperUser,
|
||||
session: DbSession,
|
||||
) -> APIResponse[BatchDetailResponse]:
|
||||
"""获取批次详情"""
|
||||
redeem_service = RedeemCodeService(session)
|
||||
|
||||
try:
|
||||
batch = await redeem_service.get_batch_detail(batch_id)
|
||||
|
||||
return APIResponse.ok(
|
||||
data=BatchDetailResponse(
|
||||
id=batch.id,
|
||||
name=batch.name,
|
||||
description=batch.description,
|
||||
face_value_units=batch.face_value,
|
||||
total_count=batch.total_count,
|
||||
used_count=batch.used_count,
|
||||
created_by=batch.created_by,
|
||||
created_at=batch.created_at,
|
||||
),
|
||||
)
|
||||
|
||||
except ResourceNotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=e.message,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 兑换码管理
|
||||
# ============================================================
|
||||
|
||||
@router.get(
|
||||
"/codes",
|
||||
response_model=APIResponse[RedeemCodeListResponse],
|
||||
summary="获取兑换码列表",
|
||||
description="获取所有兑换码,支持多种过滤条件",
|
||||
)
|
||||
async def get_codes(
|
||||
current_user: SuperUser,
|
||||
session: DbSession,
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=100),
|
||||
status_filter: RedeemCodeStatus | None = Query(None, alias="status"),
|
||||
batch_id: str | None = Query(None),
|
||||
code: str | None = Query(None, description="兑换码模糊搜索"),
|
||||
created_after: datetime | None = Query(None),
|
||||
created_before: datetime | None = Query(None),
|
||||
) -> APIResponse[RedeemCodeListResponse]:
|
||||
"""
|
||||
获取兑换码列表
|
||||
|
||||
支持过滤:
|
||||
- **status**: 状态(active/used/disabled/expired)
|
||||
- **batch_id**: 批次 ID
|
||||
- **code**: 兑换码模糊搜索
|
||||
- **created_after/created_before**: 创建时间范围
|
||||
"""
|
||||
redeem_service = RedeemCodeService(session)
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
codes, total = await redeem_service.get_codes(
|
||||
offset=offset,
|
||||
limit=page_size,
|
||||
status=status_filter,
|
||||
batch_id=batch_id,
|
||||
code_like=code,
|
||||
created_after=created_after,
|
||||
created_before=created_before,
|
||||
)
|
||||
|
||||
items = [
|
||||
RedeemCodeDetailResponse(
|
||||
id=c.id,
|
||||
code=c.code,
|
||||
face_value_units=c.face_value,
|
||||
status=c.status,
|
||||
max_uses=c.max_uses,
|
||||
used_count=c.used_count,
|
||||
expires_at=c.expires_at,
|
||||
used_at=c.used_at,
|
||||
created_at=c.created_at,
|
||||
batch_id=c.batch_id,
|
||||
batch_name=c.batch.name if c.batch else None,
|
||||
remark=c.remark,
|
||||
created_by=c.created_by,
|
||||
used_by=c.used_by,
|
||||
)
|
||||
for c in codes
|
||||
]
|
||||
|
||||
return APIResponse.ok(
|
||||
data=RedeemCodeListResponse.create(
|
||||
items=items,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/codes/{code_id}",
|
||||
response_model=APIResponse[RedeemCodeDetailResponse],
|
||||
summary="获取兑换码详情",
|
||||
)
|
||||
async def get_code_detail(
|
||||
code_id: str,
|
||||
current_user: SuperUser,
|
||||
session: DbSession,
|
||||
) -> APIResponse[RedeemCodeDetailResponse]:
|
||||
"""获取兑换码详情"""
|
||||
redeem_service = RedeemCodeService(session)
|
||||
|
||||
try:
|
||||
code = await redeem_service.get_code_detail(code_id)
|
||||
|
||||
return APIResponse.ok(
|
||||
data=RedeemCodeDetailResponse(
|
||||
id=code.id,
|
||||
code=code.code,
|
||||
face_value_units=code.face_value,
|
||||
status=code.status,
|
||||
max_uses=code.max_uses,
|
||||
used_count=code.used_count,
|
||||
expires_at=code.expires_at,
|
||||
used_at=code.used_at,
|
||||
created_at=code.created_at,
|
||||
batch_id=code.batch_id,
|
||||
batch_name=code.batch.name if code.batch else None,
|
||||
remark=code.remark,
|
||||
created_by=code.created_by,
|
||||
used_by=code.used_by,
|
||||
),
|
||||
)
|
||||
|
||||
except ResourceNotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=e.message,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/codes/{code_id}/disable",
|
||||
response_model=APIResponse[RedeemCodeDetailResponse],
|
||||
summary="禁用兑换码",
|
||||
)
|
||||
async def disable_code(
|
||||
code_id: str,
|
||||
current_user: SuperUser,
|
||||
session: DbSession,
|
||||
) -> APIResponse[RedeemCodeDetailResponse]:
|
||||
"""禁用指定兑换码"""
|
||||
redeem_service = RedeemCodeService(session)
|
||||
|
||||
try:
|
||||
code = await redeem_service.disable_code(code_id)
|
||||
|
||||
return APIResponse.ok(
|
||||
data=RedeemCodeDetailResponse(
|
||||
id=code.id,
|
||||
code=code.code,
|
||||
face_value_units=code.face_value,
|
||||
status=code.status,
|
||||
max_uses=code.max_uses,
|
||||
used_count=code.used_count,
|
||||
expires_at=code.expires_at,
|
||||
used_at=code.used_at,
|
||||
created_at=code.created_at,
|
||||
batch_id=code.batch_id,
|
||||
remark=code.remark,
|
||||
created_by=code.created_by,
|
||||
used_by=code.used_by,
|
||||
),
|
||||
message="兑换码已禁用",
|
||||
)
|
||||
|
||||
except ResourceNotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=e.message,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/codes/{code_id}/enable",
|
||||
response_model=APIResponse[RedeemCodeDetailResponse],
|
||||
summary="启用兑换码",
|
||||
)
|
||||
async def enable_code(
|
||||
code_id: str,
|
||||
current_user: SuperUser,
|
||||
session: DbSession,
|
||||
) -> APIResponse[RedeemCodeDetailResponse]:
|
||||
"""重新启用指定兑换码"""
|
||||
redeem_service = RedeemCodeService(session)
|
||||
|
||||
try:
|
||||
code = await redeem_service.enable_code(code_id)
|
||||
|
||||
return APIResponse.ok(
|
||||
data=RedeemCodeDetailResponse(
|
||||
id=code.id,
|
||||
code=code.code,
|
||||
face_value_units=code.face_value,
|
||||
status=code.status,
|
||||
max_uses=code.max_uses,
|
||||
used_count=code.used_count,
|
||||
expires_at=code.expires_at,
|
||||
used_at=code.used_at,
|
||||
created_at=code.created_at,
|
||||
batch_id=code.batch_id,
|
||||
remark=code.remark,
|
||||
created_by=code.created_by,
|
||||
used_by=code.used_by,
|
||||
),
|
||||
message="兑换码已启用",
|
||||
)
|
||||
|
||||
except ResourceNotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=e.message,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 导入导出
|
||||
# ============================================================
|
||||
|
||||
@router.post(
|
||||
"/codes/import",
|
||||
response_model=APIResponse[BulkImportResponse],
|
||||
summary="批量导入兑换码",
|
||||
description="批量导入自定义兑换码",
|
||||
)
|
||||
async def import_codes(
|
||||
request: BulkImportRequest,
|
||||
current_user: SuperUser,
|
||||
session: DbSession,
|
||||
) -> APIResponse[BulkImportResponse]:
|
||||
"""
|
||||
批量导入兑换码
|
||||
|
||||
可以导入自定义格式的兑换码。
|
||||
|
||||
- **codes**: 兑换码列表
|
||||
- **batch_name**: 批次名称(可选)
|
||||
"""
|
||||
redeem_service = RedeemCodeService(session)
|
||||
|
||||
# 转换请求数据
|
||||
codes_data = [
|
||||
{
|
||||
"code": c.code,
|
||||
"face_value_units": c.face_value_units,
|
||||
"max_uses": c.max_uses,
|
||||
"expires_at": c.expires_at,
|
||||
"remark": c.remark,
|
||||
}
|
||||
for c in request.codes
|
||||
]
|
||||
|
||||
result = await redeem_service.import_codes(
|
||||
codes_data,
|
||||
created_by=current_user.id,
|
||||
batch_name=request.batch_name,
|
||||
)
|
||||
|
||||
return APIResponse.ok(
|
||||
data=BulkImportResponse(
|
||||
success_count=result["success_count"],
|
||||
failed_count=result["failed_count"],
|
||||
failed_codes=result["failed_codes"],
|
||||
batch_id=result["batch_id"],
|
||||
),
|
||||
message=f"导入完成:成功 {result['success_count']},失败 {result['failed_count']}",
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/codes/export",
|
||||
response_model=APIResponse[ExportResponse],
|
||||
summary="导出兑换码",
|
||||
description="导出兑换码数据",
|
||||
)
|
||||
async def export_codes(
|
||||
current_user: SuperUser,
|
||||
session: DbSession,
|
||||
batch_id: str | None = Query(None),
|
||||
status_filter: RedeemCodeStatus | None = Query(None, alias="status"),
|
||||
limit: int = Query(10000, ge=1, le=50000),
|
||||
) -> APIResponse[ExportResponse]:
|
||||
"""
|
||||
导出兑换码
|
||||
|
||||
支持按批次或状态过滤。
|
||||
"""
|
||||
redeem_service = RedeemCodeService(session)
|
||||
|
||||
codes = await redeem_service.export_codes(
|
||||
batch_id=batch_id,
|
||||
status=status_filter,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
items = [
|
||||
ExportCodeItem(
|
||||
code=c["code"],
|
||||
face_value=c["face_value"],
|
||||
status=c["status"],
|
||||
max_uses=c["max_uses"],
|
||||
used_count=c["used_count"],
|
||||
expires_at=c["expires_at"],
|
||||
created_at=c["created_at"],
|
||||
used_at=c["used_at"],
|
||||
used_by=c["used_by"],
|
||||
)
|
||||
for c in codes
|
||||
]
|
||||
|
||||
return APIResponse.ok(
|
||||
data=ExportResponse(
|
||||
total=len(items),
|
||||
codes=items,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 使用日志
|
||||
# ============================================================
|
||||
|
||||
@router.get(
|
||||
"/usage-logs",
|
||||
response_model=APIResponse[UsageLogListResponse],
|
||||
summary="获取使用日志",
|
||||
description="查询兑换码使用记录",
|
||||
)
|
||||
async def get_usage_logs(
|
||||
current_user: SuperUser,
|
||||
session: DbSession,
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=100),
|
||||
redeem_code_id: str | None = Query(None),
|
||||
user_id: str | None = Query(None),
|
||||
code: str | None = Query(None, description="兑换码模糊搜索"),
|
||||
created_after: datetime | None = Query(None),
|
||||
created_before: datetime | None = Query(None),
|
||||
) -> APIResponse[UsageLogListResponse]:
|
||||
"""
|
||||
获取使用日志
|
||||
|
||||
支持过滤:
|
||||
- **redeem_code_id**: 兑换码 ID
|
||||
- **user_id**: 用户 ID
|
||||
- **code**: 兑换码模糊搜索
|
||||
- **created_after/created_before**: 使用时间范围
|
||||
"""
|
||||
redeem_service = RedeemCodeService(session)
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
logs, total = await redeem_service.get_usage_logs(
|
||||
offset=offset,
|
||||
limit=page_size,
|
||||
redeem_code_id=redeem_code_id,
|
||||
user_id=user_id,
|
||||
code_like=code,
|
||||
created_after=created_after,
|
||||
created_before=created_before,
|
||||
)
|
||||
|
||||
items = [
|
||||
UsageLogResponse(
|
||||
id=log.id,
|
||||
redeem_code_id=log.redeem_code_id,
|
||||
code_snapshot=log.code_snapshot,
|
||||
user_id=log.user_id,
|
||||
username=log.user.username if log.user else None,
|
||||
face_value=format_display(log.face_value),
|
||||
ip_address=log.ip_address,
|
||||
created_at=log.created_at,
|
||||
)
|
||||
for log in logs
|
||||
]
|
||||
|
||||
return APIResponse.ok(
|
||||
data=UsageLogListResponse.create(
|
||||
items=items,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 余额管理
|
||||
# ============================================================
|
||||
|
||||
@router.post(
|
||||
"/balance/adjust",
|
||||
response_model=APIResponse[AdminBalanceResponse],
|
||||
summary="调整用户余额",
|
||||
description="管理员手动调整用户余额",
|
||||
)
|
||||
async def adjust_balance(
|
||||
request: AdminAdjustmentRequest,
|
||||
current_user: SuperUser,
|
||||
session: DbSession,
|
||||
) -> APIResponse[AdminBalanceResponse]:
|
||||
"""
|
||||
调整用户余额
|
||||
|
||||
- **user_id**: 目标用户 ID
|
||||
- **amount**: 调整金额(正数增加,负数减少)
|
||||
- **reason**: 调整原因
|
||||
"""
|
||||
balance_service = BalanceService(session)
|
||||
|
||||
try:
|
||||
transaction = await balance_service.admin_adjust(
|
||||
request.user_id,
|
||||
request.amount_units,
|
||||
operator_id=current_user.id,
|
||||
reason=request.reason,
|
||||
)
|
||||
|
||||
# 获取更新后的余额
|
||||
balance = await balance_service.get_balance(request.user_id)
|
||||
|
||||
return APIResponse.ok(
|
||||
data=AdminBalanceResponse(
|
||||
user_id=balance.user_id,
|
||||
username=balance.user.username if balance.user else "",
|
||||
balance=format_display(balance.balance),
|
||||
available_balance=format_display(balance.available_balance),
|
||||
frozen_balance=format_display(balance.frozen_balance),
|
||||
total_recharged=format_display(balance.total_recharged),
|
||||
total_consumed=format_display(balance.total_consumed),
|
||||
version=balance.version,
|
||||
created_at=balance.created_at,
|
||||
updated_at=balance.updated_at,
|
||||
),
|
||||
message=f"余额调整成功,变动 {request.amount:+.2f}",
|
||||
)
|
||||
|
||||
except InsufficientBalanceError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=e.message,
|
||||
)
|
||||
except ValidationError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=e.message,
|
||||
)
|
||||
|
||||
191
app/api/v1/endpoints/auth.py
Normal file
191
app/api/v1/endpoints/auth.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""
|
||||
认证相关 API
|
||||
|
||||
包括注册、登录、退出、刷新令牌、修改密码等接口。
|
||||
"""
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import ActiveUser, DbSession, get_auth_service
|
||||
from app.core.exceptions import (
|
||||
InvalidCredentialsError,
|
||||
PasswordValidationError,
|
||||
ResourceConflictError,
|
||||
TokenError,
|
||||
UserDisabledError,
|
||||
)
|
||||
from app.schemas.auth import (
|
||||
LoginRequest,
|
||||
PasswordChangeRequest,
|
||||
RefreshTokenRequest,
|
||||
TokenResponse,
|
||||
)
|
||||
from app.schemas.base import APIResponse
|
||||
from app.schemas.user import UserCreate, UserResponse
|
||||
from app.services.auth import AuthService
|
||||
from app.services.user import UserService
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post(
|
||||
"/register",
|
||||
response_model=APIResponse[UserResponse],
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
summary="用户注册",
|
||||
description="创建新用户账户",
|
||||
)
|
||||
async def register(
|
||||
user_data: UserCreate,
|
||||
session: DbSession,
|
||||
) -> APIResponse[UserResponse]:
|
||||
"""
|
||||
用户注册接口
|
||||
|
||||
- **username**: 用户名(字母开头,3-32位)
|
||||
- **email**: 邮箱(可选)
|
||||
- **password**: 密码(8位以上,需包含大小写字母和数字)
|
||||
- **nickname**: 昵称(可选)
|
||||
"""
|
||||
user_service = UserService(session)
|
||||
|
||||
try:
|
||||
user = await user_service.create_user(user_data)
|
||||
return APIResponse.ok(
|
||||
data=UserResponse.model_validate(user),
|
||||
message="注册成功",
|
||||
)
|
||||
except ResourceConflictError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=e.message,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/login",
|
||||
response_model=APIResponse[TokenResponse],
|
||||
summary="用户登录",
|
||||
description="使用用户名/邮箱和密码登录",
|
||||
)
|
||||
async def login(
|
||||
login_data: LoginRequest,
|
||||
auth_service: Annotated[AuthService, Depends(get_auth_service)],
|
||||
) -> APIResponse[TokenResponse]:
|
||||
"""
|
||||
用户登录接口
|
||||
|
||||
- **username**: 用户名或邮箱
|
||||
- **password**: 密码
|
||||
|
||||
返回访问令牌和刷新令牌。
|
||||
"""
|
||||
try:
|
||||
_, tokens = await auth_service.login(
|
||||
username=login_data.username,
|
||||
password=login_data.password,
|
||||
)
|
||||
return APIResponse.ok(data=tokens, message="登录成功")
|
||||
except InvalidCredentialsError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=e.message,
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
except UserDisabledError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=e.message,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/logout",
|
||||
response_model=APIResponse[None],
|
||||
summary="用户退出",
|
||||
description="退出登录(客户端应删除本地令牌)",
|
||||
)
|
||||
async def logout(
|
||||
current_user: ActiveUser,
|
||||
) -> APIResponse[None]:
|
||||
"""
|
||||
用户退出接口
|
||||
|
||||
由于使用的是无状态 JWT,服务端不存储令牌,
|
||||
因此退出登录主要由客户端删除本地存储的令牌实现。
|
||||
|
||||
如果需要实现令牌黑名单,可以在后续版本中添加。
|
||||
"""
|
||||
# 可以在这里添加令牌黑名单逻辑
|
||||
# 或者记录退出日志
|
||||
return APIResponse.ok(message="退出成功")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/refresh",
|
||||
response_model=APIResponse[TokenResponse],
|
||||
summary="刷新令牌",
|
||||
description="使用刷新令牌获取新的访问令牌",
|
||||
)
|
||||
async def refresh_token(
|
||||
token_data: RefreshTokenRequest,
|
||||
auth_service: Annotated[AuthService, Depends(get_auth_service)],
|
||||
) -> APIResponse[TokenResponse]:
|
||||
"""
|
||||
刷新令牌接口
|
||||
|
||||
使用刷新令牌获取新的访问令牌和刷新令牌。
|
||||
"""
|
||||
try:
|
||||
tokens = await auth_service.refresh_tokens(token_data.refresh_token)
|
||||
return APIResponse.ok(data=tokens, message="刷新成功")
|
||||
except TokenError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=e.message,
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
except UserDisabledError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=e.message,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/change-password",
|
||||
response_model=APIResponse[None],
|
||||
summary="修改密码",
|
||||
description="修改当前用户密码",
|
||||
)
|
||||
async def change_password(
|
||||
password_data: PasswordChangeRequest,
|
||||
current_user: ActiveUser,
|
||||
auth_service: Annotated[AuthService, Depends(get_auth_service)],
|
||||
) -> APIResponse[None]:
|
||||
"""
|
||||
修改密码接口
|
||||
|
||||
- **current_password**: 当前密码
|
||||
- **new_password**: 新密码
|
||||
"""
|
||||
try:
|
||||
await auth_service.change_password(
|
||||
user_id=current_user.id,
|
||||
password_data=password_data,
|
||||
)
|
||||
return APIResponse.ok(message="密码修改成功")
|
||||
except InvalidCredentialsError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=e.message,
|
||||
)
|
||||
except PasswordValidationError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=e.message,
|
||||
)
|
||||
|
||||
270
app/api/v1/endpoints/balance.py
Normal file
270
app/api/v1/endpoints/balance.py
Normal file
@@ -0,0 +1,270 @@
|
||||
"""
|
||||
余额相关 API
|
||||
|
||||
包括余额查询、交易记录、兑换等接口。
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query, Request, status
|
||||
|
||||
from app.api.deps import ActiveUser, DbSession
|
||||
from app.core.exceptions import AppException, ValidationError
|
||||
from app.schemas.base import APIResponse, PaginatedResponse
|
||||
from app.schemas.balance import (
|
||||
BalanceResponse,
|
||||
TransactionResponse,
|
||||
DeductionRequest,
|
||||
DeductionResponse,
|
||||
UNITS_PER_DISPLAY,
|
||||
format_display,
|
||||
)
|
||||
from app.schemas.redeem_code import RedeemRequest, RedeemResponse
|
||||
from app.services.balance import (
|
||||
BalanceService,
|
||||
InsufficientBalanceError,
|
||||
DuplicateTransactionError,
|
||||
)
|
||||
from app.services.redeem_code import (
|
||||
RedeemCodeService,
|
||||
RedeemCodeNotFoundError,
|
||||
RedeemCodeInvalidError,
|
||||
RedeemCodeExpiredError,
|
||||
RedeemCodeUsedError,
|
||||
RedeemCodeDisabledError,
|
||||
)
|
||||
from app.models.balance import TransactionType
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 余额查询
|
||||
# ============================================================
|
||||
|
||||
@router.get(
|
||||
"",
|
||||
response_model=APIResponse[BalanceResponse],
|
||||
summary="获取当前用户余额",
|
||||
description="获取当前登录用户的余额信息",
|
||||
)
|
||||
async def get_my_balance(
|
||||
current_user: ActiveUser,
|
||||
session: DbSession,
|
||||
) -> APIResponse[BalanceResponse]:
|
||||
"""
|
||||
获取当前用户余额
|
||||
|
||||
返回:
|
||||
- **balance**: 当前总余额
|
||||
- **available_balance**: 可用余额(总余额 - 冻结)
|
||||
- **frozen_balance**: 冻结余额
|
||||
- **total_recharged**: 累计充值
|
||||
- **total_consumed**: 累计消费
|
||||
"""
|
||||
balance_service = BalanceService(session)
|
||||
balance = await balance_service.get_balance(current_user.id)
|
||||
|
||||
return APIResponse.ok(
|
||||
data=BalanceResponse(
|
||||
user_id=balance.user_id,
|
||||
balance_units=balance.balance,
|
||||
frozen_units=balance.frozen_balance,
|
||||
total_recharged_units=balance.total_recharged,
|
||||
total_consumed_units=balance.total_consumed,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 交易记录
|
||||
# ============================================================
|
||||
|
||||
@router.get(
|
||||
"/transactions",
|
||||
response_model=APIResponse[PaginatedResponse[TransactionResponse]],
|
||||
summary="获取交易记录",
|
||||
description="获取当前用户的余额交易记录",
|
||||
)
|
||||
async def get_my_transactions(
|
||||
current_user: ActiveUser,
|
||||
session: DbSession,
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||
transaction_type: TransactionType | None = Query(
|
||||
None, description="交易类型过滤"
|
||||
),
|
||||
) -> APIResponse[PaginatedResponse[TransactionResponse]]:
|
||||
"""
|
||||
获取交易记录
|
||||
|
||||
支持按交易类型过滤:
|
||||
- **recharge**: 充值
|
||||
- **deduction**: 扣款
|
||||
- **refund**: 退款
|
||||
- **adjustment**: 管理员调整
|
||||
"""
|
||||
balance_service = BalanceService(session)
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
transactions, total = await balance_service.get_transactions(
|
||||
current_user.id,
|
||||
offset=offset,
|
||||
limit=page_size,
|
||||
transaction_type=transaction_type,
|
||||
)
|
||||
|
||||
items = [
|
||||
TransactionResponse(
|
||||
id=t.id,
|
||||
transaction_type=t.transaction_type,
|
||||
status=t.status,
|
||||
amount_units=t.amount,
|
||||
balance_before_units=t.balance_before,
|
||||
balance_after_units=t.balance_after,
|
||||
reference_type=t.reference_type,
|
||||
reference_id=t.reference_id,
|
||||
description=t.description,
|
||||
created_at=t.created_at,
|
||||
)
|
||||
for t in transactions
|
||||
]
|
||||
|
||||
return APIResponse.ok(
|
||||
data=PaginatedResponse.create(
|
||||
items=items,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 兑换码兑换
|
||||
# ============================================================
|
||||
|
||||
@router.post(
|
||||
"/redeem",
|
||||
response_model=APIResponse[RedeemResponse],
|
||||
summary="兑换余额",
|
||||
description="使用兑换码充值余额",
|
||||
)
|
||||
async def redeem_code(
|
||||
request: Request,
|
||||
redeem_request: RedeemRequest,
|
||||
current_user: ActiveUser,
|
||||
session: DbSession,
|
||||
) -> APIResponse[RedeemResponse]:
|
||||
"""
|
||||
使用兑换码充值余额
|
||||
|
||||
- **code**: 兑换码(格式如 XXXX-XXXX-XXXX-XXXX)
|
||||
|
||||
返回兑换结果,包括充值金额和余额变化。
|
||||
"""
|
||||
redeem_service = RedeemCodeService(session)
|
||||
|
||||
# 获取客户端信息
|
||||
ip_address = request.client.host if request.client else None
|
||||
user_agent = request.headers.get("user-agent")
|
||||
|
||||
try:
|
||||
result = await redeem_service.redeem(
|
||||
current_user.id,
|
||||
redeem_request.code,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
|
||||
return APIResponse.ok(
|
||||
data=RedeemResponse(
|
||||
success=True,
|
||||
message="兑换成功",
|
||||
face_value=result["face_value"],
|
||||
balance_before=result["balance_before"],
|
||||
balance_after=result["balance_after"],
|
||||
),
|
||||
message="兑换成功",
|
||||
)
|
||||
|
||||
except RedeemCodeNotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=e.message,
|
||||
)
|
||||
except (
|
||||
RedeemCodeInvalidError,
|
||||
RedeemCodeExpiredError,
|
||||
RedeemCodeUsedError,
|
||||
RedeemCodeDisabledError,
|
||||
) as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=e.message,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 扣款(内部 API / 服务间调用)
|
||||
# ============================================================
|
||||
|
||||
@router.post(
|
||||
"/deduct",
|
||||
response_model=APIResponse[DeductionResponse],
|
||||
summary="扣款",
|
||||
description="从当前用户余额中扣款(通常由其他服务调用)",
|
||||
include_in_schema=True, # 可设为 False 隐藏此 API
|
||||
)
|
||||
async def deduct_balance(
|
||||
deduction_request: DeductionRequest,
|
||||
current_user: ActiveUser,
|
||||
session: DbSession,
|
||||
) -> APIResponse[DeductionResponse]:
|
||||
"""
|
||||
扣款
|
||||
|
||||
从当前用户余额中扣除指定金额。
|
||||
|
||||
- **amount**: 扣款金额(如 1.00)
|
||||
- **reference_type**: 关联业务类型(如 api_call)
|
||||
- **reference_id**: 关联业务 ID
|
||||
- **description**: 交易描述
|
||||
- **idempotency_key**: 幂等键(防止重复扣款)
|
||||
"""
|
||||
balance_service = BalanceService(session)
|
||||
|
||||
try:
|
||||
transaction = await balance_service.deduct(
|
||||
current_user.id,
|
||||
deduction_request.amount_units,
|
||||
reference_type=deduction_request.reference_type,
|
||||
reference_id=deduction_request.reference_id,
|
||||
description=deduction_request.description,
|
||||
idempotency_key=deduction_request.idempotency_key,
|
||||
)
|
||||
|
||||
return APIResponse.ok(
|
||||
data=DeductionResponse(
|
||||
transaction_id=transaction.id,
|
||||
amount=format_display(abs(transaction.amount)),
|
||||
balance_before=format_display(transaction.balance_before),
|
||||
balance_after=format_display(transaction.balance_after),
|
||||
),
|
||||
message="扣款成功",
|
||||
)
|
||||
|
||||
except InsufficientBalanceError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail=e.message,
|
||||
)
|
||||
except DuplicateTransactionError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=e.message,
|
||||
)
|
||||
except ValidationError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=e.message,
|
||||
)
|
||||
|
||||
157
app/api/v1/endpoints/oauth2.py
Normal file
157
app/api/v1/endpoints/oauth2.py
Normal file
@@ -0,0 +1,157 @@
|
||||
"""
|
||||
OAuth2 认证 API
|
||||
|
||||
提供 OAuth2 第三方登录功能。
|
||||
"""
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import DbSession
|
||||
from app.core.config import settings
|
||||
from app.core.exceptions import AuthenticationError
|
||||
from app.schemas.base import APIResponse
|
||||
from app.schemas.oauth2 import OAuth2AuthorizeResponse, OAuth2LoginResponse
|
||||
from app.services.oauth2 import OAuth2Service, OAuth2StateError
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def get_oauth2_service(session: DbSession) -> OAuth2Service:
|
||||
"""获取 OAuth2 服务实例"""
|
||||
return OAuth2Service(session)
|
||||
|
||||
|
||||
def _get_redirect_uri(request: Request) -> str:
|
||||
"""
|
||||
获取 OAuth2 回调 URL
|
||||
|
||||
根据请求动态构建完整的回调 URL
|
||||
"""
|
||||
# 优先使用 X-Forwarded 头(反向代理场景)
|
||||
scheme = request.headers.get("X-Forwarded-Proto", request.url.scheme)
|
||||
host = request.headers.get("X-Forwarded-Host", request.url.netloc)
|
||||
|
||||
return f"{scheme}://{host}{settings.oauth2_callback_path}"
|
||||
|
||||
|
||||
@router.get(
|
||||
"/authorize",
|
||||
response_model=APIResponse[OAuth2AuthorizeResponse],
|
||||
summary="获取 OAuth2 授权 URL",
|
||||
description="获取第三方登录授权页面 URL",
|
||||
)
|
||||
async def get_authorize_url(
|
||||
request: Request,
|
||||
session: DbSession,
|
||||
) -> APIResponse[OAuth2AuthorizeResponse]:
|
||||
"""
|
||||
获取 OAuth2 授权 URL
|
||||
|
||||
返回授权页面 URL 和状态码,客户端应重定向用户到该 URL。
|
||||
|
||||
流程:
|
||||
1. 前端调用此接口获取授权 URL
|
||||
2. 前端将用户重定向到授权 URL
|
||||
3. 用户在第三方平台完成授权
|
||||
4. 第三方平台重定向回回调 URL,携带 code 和 state
|
||||
5. 前端调用 /callback 接口完成登录
|
||||
"""
|
||||
# 检查 OAuth2 是否已配置
|
||||
if not settings.oauth2_client_id or not settings.oauth2_client_secret:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="OAuth2 登录未配置",
|
||||
)
|
||||
|
||||
oauth2_service = OAuth2Service(session)
|
||||
redirect_uri = _get_redirect_uri(request)
|
||||
|
||||
authorize_url, state = oauth2_service.generate_authorize_url(redirect_uri)
|
||||
|
||||
return APIResponse.ok(
|
||||
data=OAuth2AuthorizeResponse(
|
||||
authorize_url=authorize_url,
|
||||
state=state,
|
||||
),
|
||||
message="请重定向到授权 URL",
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/callback",
|
||||
response_model=APIResponse[OAuth2LoginResponse],
|
||||
summary="OAuth2 回调",
|
||||
description="处理 OAuth2 授权回调,完成登录",
|
||||
)
|
||||
async def oauth2_callback(
|
||||
request: Request,
|
||||
code: Annotated[str, Query(description="授权码")],
|
||||
state: Annotated[str, Query(description="状态码")],
|
||||
session: DbSession,
|
||||
) -> APIResponse[OAuth2LoginResponse]:
|
||||
"""
|
||||
OAuth2 回调接口
|
||||
|
||||
处理第三方平台的授权回调:
|
||||
1. 验证 state 防止 CSRF 攻击
|
||||
2. 用 code 换取访问令牌
|
||||
3. 获取用户信息
|
||||
4. 创建或关联本地用户
|
||||
5. 返回 JWT 令牌
|
||||
"""
|
||||
oauth2_service = OAuth2Service(session)
|
||||
redirect_uri = _get_redirect_uri(request)
|
||||
|
||||
try:
|
||||
user, tokens, is_new_user = await oauth2_service.authenticate(
|
||||
code=code,
|
||||
state=state,
|
||||
redirect_uri=redirect_uri,
|
||||
)
|
||||
|
||||
return APIResponse.ok(
|
||||
data=OAuth2LoginResponse(
|
||||
access_token=tokens.access_token,
|
||||
refresh_token=tokens.refresh_token,
|
||||
token_type=tokens.token_type,
|
||||
expires_in=tokens.expires_in,
|
||||
is_new_user=is_new_user,
|
||||
),
|
||||
message="登录成功" if not is_new_user else "注册并登录成功",
|
||||
)
|
||||
|
||||
except OAuth2StateError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=e.message,
|
||||
)
|
||||
except AuthenticationError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=e.message,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/callback",
|
||||
response_model=APIResponse[OAuth2LoginResponse],
|
||||
summary="OAuth2 回调 (POST)",
|
||||
description="处理 OAuth2 授权回调(POST 方式)",
|
||||
)
|
||||
async def oauth2_callback_post(
|
||||
request: Request,
|
||||
code: Annotated[str, Query(description="授权码")],
|
||||
state: Annotated[str, Query(description="状态码")],
|
||||
session: DbSession,
|
||||
) -> APIResponse[OAuth2LoginResponse]:
|
||||
"""
|
||||
OAuth2 回调接口(POST 方式)
|
||||
|
||||
某些场景下前端可能使用 POST 方式调用回调。
|
||||
逻辑与 GET 方式相同。
|
||||
"""
|
||||
return await oauth2_callback(request, code, state, session)
|
||||
|
||||
103
app/api/v1/endpoints/users.py
Normal file
103
app/api/v1/endpoints/users.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""
|
||||
用户相关 API
|
||||
|
||||
包括获取用户信息、更新用户资料等接口。
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
|
||||
from app.api.deps import ActiveUser, DbSession
|
||||
from app.core.exceptions import ResourceConflictError, UserNotFoundError
|
||||
from app.schemas.base import APIResponse
|
||||
from app.schemas.user import UserResponse, UserUpdate
|
||||
from app.services.user import UserService
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get(
|
||||
"/me",
|
||||
response_model=APIResponse[UserResponse],
|
||||
summary="获取当前用户信息",
|
||||
description="获取当前登录用户的详细信息",
|
||||
)
|
||||
async def get_current_user_info(
|
||||
current_user: ActiveUser,
|
||||
) -> APIResponse[UserResponse]:
|
||||
"""
|
||||
获取当前用户信息
|
||||
|
||||
返回当前登录用户的完整信息。
|
||||
"""
|
||||
return APIResponse.ok(
|
||||
data=UserResponse.model_validate(current_user),
|
||||
)
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/me",
|
||||
response_model=APIResponse[UserResponse],
|
||||
summary="更新当前用户信息",
|
||||
description="更新当前登录用户的资料",
|
||||
)
|
||||
async def update_current_user(
|
||||
update_data: UserUpdate,
|
||||
current_user: ActiveUser,
|
||||
session: DbSession,
|
||||
) -> APIResponse[UserResponse]:
|
||||
"""
|
||||
更新当前用户信息
|
||||
|
||||
支持更新:
|
||||
- **nickname**: 昵称
|
||||
- **email**: 邮箱
|
||||
- **avatar_url**: 头像 URL
|
||||
- **bio**: 个人简介
|
||||
"""
|
||||
user_service = UserService(session)
|
||||
|
||||
try:
|
||||
user = await user_service.update_user(
|
||||
user_id=current_user.id,
|
||||
update_data=update_data,
|
||||
)
|
||||
return APIResponse.ok(
|
||||
data=UserResponse.model_validate(user),
|
||||
message="更新成功",
|
||||
)
|
||||
except ResourceConflictError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=e.message,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{user_id}",
|
||||
response_model=APIResponse[UserResponse],
|
||||
summary="获取指定用户信息",
|
||||
description="获取指定用户的公开信息",
|
||||
)
|
||||
async def get_user_by_id(
|
||||
user_id: str,
|
||||
current_user: ActiveUser,
|
||||
session: DbSession,
|
||||
) -> APIResponse[UserResponse]:
|
||||
"""
|
||||
获取指定用户信息
|
||||
|
||||
- **user_id**: 用户 ID
|
||||
"""
|
||||
user_service = UserService(session)
|
||||
|
||||
try:
|
||||
user = await user_service.get_user_by_id(user_id)
|
||||
return APIResponse.ok(
|
||||
data=UserResponse.model_validate(user),
|
||||
)
|
||||
except UserNotFoundError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="用户不存在",
|
||||
)
|
||||
|
||||
45
app/api/v1/router.py
Normal file
45
app/api/v1/router.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""
|
||||
API v1 路由聚合
|
||||
|
||||
汇总所有 v1 版本的 API 路由。
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api.v1.endpoints import auth, oauth2, users, balance
|
||||
from app.api.v1.endpoints.admin import redeem_codes as admin_redeem_codes
|
||||
|
||||
api_router = APIRouter()
|
||||
|
||||
# 注册子路由
|
||||
api_router.include_router(
|
||||
auth.router,
|
||||
prefix="/auth",
|
||||
tags=["认证"],
|
||||
)
|
||||
|
||||
api_router.include_router(
|
||||
oauth2.router,
|
||||
prefix="/auth/oauth2",
|
||||
tags=["OAuth2 登录"],
|
||||
)
|
||||
|
||||
api_router.include_router(
|
||||
users.router,
|
||||
prefix="/users",
|
||||
tags=["用户"],
|
||||
)
|
||||
|
||||
api_router.include_router(
|
||||
balance.router,
|
||||
prefix="/balance",
|
||||
tags=["余额"],
|
||||
)
|
||||
|
||||
# 管理员路由
|
||||
api_router.include_router(
|
||||
admin_redeem_codes.router,
|
||||
prefix="/admin/redeem-codes",
|
||||
tags=["管理员 - 兑换码"],
|
||||
)
|
||||
|
||||
29
app/core/__init__.py
Normal file
29
app/core/__init__.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""核心功能模块"""
|
||||
|
||||
from .config import Settings, get_settings, settings
|
||||
from .config_loader import (
|
||||
ConfigLoader,
|
||||
ConfigSource,
|
||||
ConfigSourceError,
|
||||
EnvConfigSource,
|
||||
YamlConfigSource,
|
||||
config_loader,
|
||||
create_config_loader,
|
||||
get_config_loader,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# 配置系统
|
||||
"Settings",
|
||||
"get_settings",
|
||||
"settings",
|
||||
# 配置加载器
|
||||
"ConfigLoader",
|
||||
"ConfigSource",
|
||||
"ConfigSourceError",
|
||||
"EnvConfigSource",
|
||||
"YamlConfigSource",
|
||||
"config_loader",
|
||||
"create_config_loader",
|
||||
"get_config_loader",
|
||||
]
|
||||
202
app/core/config.py
Normal file
202
app/core/config.py
Normal file
@@ -0,0 +1,202 @@
|
||||
"""
|
||||
应用配置管理
|
||||
|
||||
使用 Pydantic Settings 进行类型安全的配置管理,
|
||||
集成 ConfigLoader 支持环境变量和 YAML 配置文件。
|
||||
|
||||
配置优先级(从高到低):
|
||||
1. 环境变量(前缀 SATONANO_)
|
||||
2. config.yaml 文件
|
||||
3. .env 文件
|
||||
4. 默认值
|
||||
"""
|
||||
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import Field, computed_field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
from .config_loader import create_config_loader
|
||||
|
||||
|
||||
def _yaml_settings_source() -> dict[str, Any]:
|
||||
"""
|
||||
YAML 配置源工厂函数
|
||||
|
||||
为 Pydantic Settings 提供 YAML 配置数据
|
||||
"""
|
||||
# 查找配置文件路径
|
||||
config_paths = [
|
||||
Path("config.yaml"),
|
||||
Path("config.yml"),
|
||||
Path(__file__).parent.parent.parent / "config.yaml",
|
||||
Path(__file__).parent.parent.parent / "config.yml",
|
||||
]
|
||||
|
||||
yaml_path = None
|
||||
for path in config_paths:
|
||||
if path.exists():
|
||||
yaml_path = path
|
||||
break
|
||||
|
||||
if yaml_path is None:
|
||||
return {}
|
||||
|
||||
# 使用 ConfigLoader 加载配置(不带环境变量前缀,让 pydantic 处理)
|
||||
loader = create_config_loader(
|
||||
yaml_path=yaml_path,
|
||||
env_prefix="", # 不读取环境变量,由 pydantic-settings 处理
|
||||
yaml_required=False,
|
||||
)
|
||||
return loader.load()
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""应用配置"""
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_prefix="SATONANO_",
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
case_sensitive=False,
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
# 应用基础配置
|
||||
app_name: str = "SatoNano"
|
||||
app_version: str = "0.1.0"
|
||||
debug: bool = False
|
||||
environment: Literal["development", "staging", "production"] = "development"
|
||||
|
||||
# API 配置
|
||||
api_v1_prefix: str = "/api/v1"
|
||||
|
||||
# 安全配置
|
||||
secret_key: str = Field(
|
||||
default="CHANGE-THIS-SECRET-KEY-IN-PRODUCTION",
|
||||
description="JWT 签名密钥,生产环境必须更改",
|
||||
)
|
||||
algorithm: str = "HS256"
|
||||
access_token_expire_minutes: int = 30
|
||||
refresh_token_expire_days: int = 7
|
||||
|
||||
# 数据库配置
|
||||
database_type: Literal["sqlite", "mysql", "postgresql"] = "sqlite"
|
||||
database_host: str = "localhost"
|
||||
database_port: int = 3306
|
||||
database_name: str = "satonano"
|
||||
database_username: str = ""
|
||||
database_password: str = ""
|
||||
database_sqlite_path: str = "./satonano.db"
|
||||
database_echo: bool = False
|
||||
|
||||
# OAuth2 配置 (Linux.do)
|
||||
oauth2_client_id: str = ""
|
||||
oauth2_client_secret: str = ""
|
||||
oauth2_callback_path: str = "/oauth2/callback"
|
||||
# 首选端点
|
||||
oauth2_authorize_endpoint: str = "https://connect.linux.do/oauth2/authorize"
|
||||
oauth2_token_endpoint: str = "https://connect.linux.do/oauth2/token"
|
||||
oauth2_user_info_endpoint: str = "https://connect.linux.do/api/user"
|
||||
# 备用端点
|
||||
oauth2_authorize_endpoint_reserve: str = "https://connect.linuxdo.org/oauth2/authorize"
|
||||
oauth2_token_endpoint_reserve: str = "https://connect.linuxdo.org/oauth2/token"
|
||||
oauth2_user_info_endpoint_reserve: str = "https://connect.linuxdo.org/api/user"
|
||||
oauth2_request_timeout: int = 10 # 请求超时时间(秒)
|
||||
|
||||
# 密码策略
|
||||
password_min_length: int = 8
|
||||
password_max_length: int = 128
|
||||
password_require_uppercase: bool = True
|
||||
password_require_lowercase: bool = True
|
||||
password_require_digit: bool = True
|
||||
password_require_special: bool = False
|
||||
|
||||
# 用户名策略
|
||||
username_min_length: int = 3
|
||||
username_max_length: int = 32
|
||||
|
||||
# 前端静态文件配置
|
||||
frontend_static_path: str = "./frontend/out"
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def database_url(self) -> str:
|
||||
"""
|
||||
动态构建数据库连接 URL
|
||||
|
||||
根据 database_type 自动生成对应的连接字符串
|
||||
"""
|
||||
if self.database_type == "sqlite":
|
||||
return f"sqlite+aiosqlite:///{self.database_sqlite_path}"
|
||||
|
||||
if self.database_type == "mysql":
|
||||
return (
|
||||
f"mysql+aiomysql://{self.database_username}:{self.database_password}"
|
||||
f"@{self.database_host}:{self.database_port}/{self.database_name}"
|
||||
)
|
||||
|
||||
if self.database_type == "postgresql":
|
||||
return (
|
||||
f"postgresql+asyncpg://{self.database_username}:{self.database_password}"
|
||||
f"@{self.database_host}:{self.database_port}/{self.database_name}"
|
||||
)
|
||||
|
||||
return f"sqlite+aiosqlite:///{self.database_sqlite_path}"
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def is_production(self) -> bool:
|
||||
"""是否为生产环境"""
|
||||
return self.environment == "production"
|
||||
|
||||
@classmethod
|
||||
def settings_customise_sources(
|
||||
cls,
|
||||
settings_cls: type[BaseSettings],
|
||||
init_settings: Any,
|
||||
env_settings: Any,
|
||||
dotenv_settings: Any,
|
||||
file_secret_settings: Any,
|
||||
) -> tuple[Any, ...]:
|
||||
"""
|
||||
自定义配置源优先级
|
||||
|
||||
优先级顺序(从高到低):
|
||||
1. 初始化参数
|
||||
2. 环境变量
|
||||
3. YAML 文件
|
||||
4. .env 文件
|
||||
5. 文件密钥
|
||||
"""
|
||||
|
||||
class YamlConfigSettingsSource:
|
||||
"""YAML 配置源适配器"""
|
||||
|
||||
def __init__(self, settings_cls: type[BaseSettings]) -> None:
|
||||
self.settings_cls = settings_cls
|
||||
self._yaml_data: dict[str, Any] | None = None
|
||||
|
||||
def __call__(self) -> dict[str, Any]:
|
||||
if self._yaml_data is None:
|
||||
self._yaml_data = _yaml_settings_source()
|
||||
return self._yaml_data
|
||||
|
||||
return (
|
||||
init_settings,
|
||||
env_settings,
|
||||
YamlConfigSettingsSource(settings_cls),
|
||||
dotenv_settings,
|
||||
file_secret_settings,
|
||||
)
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_settings() -> Settings:
|
||||
"""获取配置单例"""
|
||||
return Settings()
|
||||
|
||||
|
||||
settings = get_settings()
|
||||
387
app/core/config_loader.py
Normal file
387
app/core/config_loader.py
Normal file
@@ -0,0 +1,387 @@
|
||||
"""
|
||||
配置加载器模块
|
||||
|
||||
提供统一的配置加载机制,支持多数据源配置合并:
|
||||
- 环境变量(优先级最高)
|
||||
- YAML 配置文件
|
||||
- 默认值(优先级最低)
|
||||
|
||||
设计原则:
|
||||
- 单一职责:仅负责配置的加载和合并
|
||||
- 依赖倒置:通过抽象接口解耦配置源
|
||||
- 开闭原则:易于扩展新的配置源
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, TypeVar
|
||||
|
||||
import yaml
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class ConfigSourceError(Exception):
|
||||
"""配置源加载错误"""
|
||||
|
||||
def __init__(self, source: str, message: str) -> None:
|
||||
self.source = source
|
||||
self.message = message
|
||||
super().__init__(f"[{source}] {message}")
|
||||
|
||||
|
||||
class ConfigSource(ABC):
|
||||
"""配置源抽象基类"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""配置源名称"""
|
||||
...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def priority(self) -> int:
|
||||
"""优先级,数值越大优先级越高"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def load(self) -> dict[str, Any]:
|
||||
"""加载配置,返回扁平化的键值对"""
|
||||
...
|
||||
|
||||
|
||||
class EnvConfigSource(ConfigSource):
|
||||
"""环境变量配置源"""
|
||||
|
||||
def __init__(self, prefix: str = "") -> None:
|
||||
"""
|
||||
初始化环境变量配置源
|
||||
|
||||
Args:
|
||||
prefix: 环境变量前缀,用于过滤相关配置
|
||||
例如 prefix="SATONANO_" 时只读取以此为前缀的变量
|
||||
"""
|
||||
self._prefix = prefix.upper()
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "environment"
|
||||
|
||||
@property
|
||||
def priority(self) -> int:
|
||||
return 100 # 最高优先级
|
||||
|
||||
def load(self) -> dict[str, Any]:
|
||||
"""
|
||||
加载环境变量
|
||||
|
||||
Returns:
|
||||
配置字典,键名转换为小写并移除前缀
|
||||
"""
|
||||
config: dict[str, Any] = {}
|
||||
prefix_len = len(self._prefix)
|
||||
|
||||
for key, value in os.environ.items():
|
||||
if self._prefix and not key.upper().startswith(self._prefix):
|
||||
continue
|
||||
|
||||
# 移除前缀并转为小写
|
||||
config_key = key[prefix_len:].lower() if self._prefix else key.lower()
|
||||
config[config_key] = self._parse_value(value)
|
||||
|
||||
return config
|
||||
|
||||
@staticmethod
|
||||
def _parse_value(value: str) -> Any:
|
||||
"""
|
||||
解析环境变量值,自动转换类型
|
||||
|
||||
Args:
|
||||
value: 原始字符串值
|
||||
|
||||
Returns:
|
||||
转换后的值(bool/int/float/str)
|
||||
"""
|
||||
# 布尔值
|
||||
if value.lower() in ("true", "yes", "1", "on"):
|
||||
return True
|
||||
if value.lower() in ("false", "no", "0", "off"):
|
||||
return False
|
||||
|
||||
# 整数
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# 浮点数
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return value
|
||||
|
||||
|
||||
class YamlConfigSource(ConfigSource):
|
||||
"""YAML 配置文件源"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str | Path = "config.yaml",
|
||||
required: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
初始化 YAML 配置源
|
||||
|
||||
Args:
|
||||
file_path: 配置文件路径
|
||||
required: 是否必须存在,为 True 时文件不存在会抛出异常
|
||||
"""
|
||||
self._file_path = Path(file_path)
|
||||
self._required = required
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f"yaml:{self._file_path}"
|
||||
|
||||
@property
|
||||
def priority(self) -> int:
|
||||
return 50 # 中等优先级
|
||||
|
||||
def load(self) -> dict[str, Any]:
|
||||
"""
|
||||
加载 YAML 配置文件
|
||||
|
||||
Returns:
|
||||
配置字典
|
||||
|
||||
Raises:
|
||||
ConfigSourceError: 文件不存在(required=True)或解析错误
|
||||
"""
|
||||
if not self._file_path.exists():
|
||||
if self._required:
|
||||
raise ConfigSourceError(
|
||||
self.name,
|
||||
f"配置文件不存在: {self._file_path.absolute()}",
|
||||
)
|
||||
return {}
|
||||
|
||||
try:
|
||||
with open(self._file_path, encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
except yaml.YAMLError as e:
|
||||
raise ConfigSourceError(self.name, f"YAML 解析错误: {e}") from e
|
||||
|
||||
if data is None:
|
||||
return {}
|
||||
|
||||
# 处理列表格式(兼容用户示例格式)
|
||||
if isinstance(data, list):
|
||||
return self._flatten_list(data)
|
||||
|
||||
# 标准字典格式
|
||||
if isinstance(data, dict):
|
||||
return self._normalize_keys(data)
|
||||
|
||||
raise ConfigSourceError(
|
||||
self.name,
|
||||
f"不支持的配置格式,期望 dict 或 list,得到 {type(data).__name__}",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _flatten_list(items: list) -> dict[str, Any]:
|
||||
"""
|
||||
将列表格式配置展平为字典
|
||||
|
||||
支持格式:
|
||||
- key: value
|
||||
- key: value
|
||||
"""
|
||||
result: dict[str, Any] = {}
|
||||
for item in items:
|
||||
if isinstance(item, dict):
|
||||
for key, value in item.items():
|
||||
result[key.lower()] = value
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _normalize_keys(data: dict) -> dict[str, Any]:
|
||||
"""键名标准化为小写"""
|
||||
return {k.lower(): v for k, v in data.items()}
|
||||
|
||||
|
||||
class ConfigLoader:
|
||||
"""
|
||||
配置加载器
|
||||
|
||||
负责从多个配置源加载并合并配置,按优先级覆盖。
|
||||
|
||||
Example:
|
||||
>>> loader = ConfigLoader()
|
||||
>>> loader.add_source(YamlConfigSource("config.yaml"))
|
||||
>>> loader.add_source(EnvConfigSource("SATONANO_"))
|
||||
>>> config = loader.load()
|
||||
>>> db_type = loader.get("database_type", "sqlite")
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._sources: list[ConfigSource] = []
|
||||
self._config: dict[str, Any] = {}
|
||||
self._loaded = False
|
||||
|
||||
def add_source(self, source: ConfigSource) -> ConfigLoader:
|
||||
"""
|
||||
添加配置源
|
||||
|
||||
Args:
|
||||
source: 配置源实例
|
||||
|
||||
Returns:
|
||||
self,支持链式调用
|
||||
"""
|
||||
self._sources.append(source)
|
||||
self._loaded = False # 标记需要重新加载
|
||||
return self
|
||||
|
||||
def load(self) -> dict[str, Any]:
|
||||
"""
|
||||
加载并合并所有配置源
|
||||
|
||||
Returns:
|
||||
合并后的配置字典
|
||||
"""
|
||||
# 按优先级升序排序,后加载的覆盖先加载的
|
||||
sorted_sources = sorted(self._sources, key=lambda s: s.priority)
|
||||
|
||||
self._config = {}
|
||||
for source in sorted_sources:
|
||||
source_config = source.load()
|
||||
self._config.update(source_config)
|
||||
|
||||
self._loaded = True
|
||||
return self._config.copy()
|
||||
|
||||
def get(self, key: str, default: T = None) -> T | Any:
|
||||
"""
|
||||
获取配置值
|
||||
|
||||
Args:
|
||||
key: 配置键名(不区分大小写)
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
配置值或默认值
|
||||
"""
|
||||
if not self._loaded:
|
||||
self.load()
|
||||
return self._config.get(key.lower(), default)
|
||||
|
||||
def get_str(self, key: str, default: str = "") -> str:
|
||||
"""获取字符串配置"""
|
||||
value = self.get(key, default)
|
||||
return str(value) if value is not None else default
|
||||
|
||||
def get_int(self, key: str, default: int = 0) -> int:
|
||||
"""获取整数配置"""
|
||||
value = self.get(key, default)
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
try:
|
||||
return int(value)
|
||||
except (ValueError, TypeError):
|
||||
return default
|
||||
|
||||
def get_float(self, key: str, default: float = 0.0) -> float:
|
||||
"""获取浮点数配置"""
|
||||
value = self.get(key, default)
|
||||
if isinstance(value, float):
|
||||
return value
|
||||
try:
|
||||
return float(value)
|
||||
except (ValueError, TypeError):
|
||||
return default
|
||||
|
||||
def get_bool(self, key: str, default: bool = False) -> bool:
|
||||
"""获取布尔配置"""
|
||||
value = self.get(key, default)
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return value.lower() in ("true", "yes", "1", "on")
|
||||
return bool(value)
|
||||
|
||||
def require(self, key: str) -> Any:
|
||||
"""
|
||||
获取必需的配置值
|
||||
|
||||
Args:
|
||||
key: 配置键名
|
||||
|
||||
Returns:
|
||||
配置值
|
||||
|
||||
Raises:
|
||||
KeyError: 配置不存在
|
||||
"""
|
||||
if not self._loaded:
|
||||
self.load()
|
||||
|
||||
key_lower = key.lower()
|
||||
if key_lower not in self._config:
|
||||
raise KeyError(f"缺少必需的配置项: {key}")
|
||||
return self._config[key_lower]
|
||||
|
||||
def all(self) -> dict[str, Any]:
|
||||
"""获取所有配置的副本"""
|
||||
if not self._loaded:
|
||||
self.load()
|
||||
return self._config.copy()
|
||||
|
||||
def __contains__(self, key: str) -> bool:
|
||||
"""检查配置是否存在"""
|
||||
if not self._loaded:
|
||||
self.load()
|
||||
return key.lower() in self._config
|
||||
|
||||
def __repr__(self) -> str:
|
||||
sources = ", ".join(s.name for s in self._sources)
|
||||
return f"ConfigLoader(sources=[{sources}], loaded={self._loaded})"
|
||||
|
||||
|
||||
def create_config_loader(
|
||||
yaml_path: str | Path = "config.yaml",
|
||||
env_prefix: str = "SATONANO_",
|
||||
yaml_required: bool = False,
|
||||
) -> ConfigLoader:
|
||||
"""
|
||||
创建预配置的配置加载器
|
||||
|
||||
Args:
|
||||
yaml_path: YAML 配置文件路径
|
||||
env_prefix: 环境变量前缀
|
||||
yaml_required: YAML 文件是否必须存在
|
||||
|
||||
Returns:
|
||||
配置好的 ConfigLoader 实例
|
||||
"""
|
||||
loader = ConfigLoader()
|
||||
loader.add_source(YamlConfigSource(yaml_path, required=yaml_required))
|
||||
loader.add_source(EnvConfigSource(env_prefix))
|
||||
return loader
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_config_loader() -> ConfigLoader:
|
||||
"""获取全局配置加载器单例"""
|
||||
return create_config_loader()
|
||||
|
||||
|
||||
# 便捷访问的全局实例
|
||||
config_loader = get_config_loader()
|
||||
|
||||
224
app/core/exceptions.py
Normal file
224
app/core/exceptions.py
Normal file
@@ -0,0 +1,224 @@
|
||||
"""
|
||||
自定义异常类
|
||||
|
||||
定义业务层面的异常,便于统一处理和返回合适的 HTTP 响应。
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class AppException(Exception):
|
||||
"""应用基础异常"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
code: str = "APP_ERROR",
|
||||
details: dict[str, Any] | None = None,
|
||||
):
|
||||
self.message = message
|
||||
self.code = code
|
||||
self.details = details or {}
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class AuthenticationError(AppException):
|
||||
"""认证错误"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "认证失败",
|
||||
code: str = "AUTHENTICATION_ERROR",
|
||||
details: dict[str, Any] | None = None,
|
||||
):
|
||||
super().__init__(message, code, details)
|
||||
|
||||
|
||||
class InvalidCredentialsError(AuthenticationError):
|
||||
"""无效凭证"""
|
||||
|
||||
def __init__(self, message: str = "用户名或密码错误"):
|
||||
super().__init__(message, "INVALID_CREDENTIALS")
|
||||
|
||||
|
||||
class TokenError(AuthenticationError):
|
||||
"""令牌错误"""
|
||||
|
||||
def __init__(self, message: str = "令牌无效或已过期"):
|
||||
super().__init__(message, "TOKEN_ERROR")
|
||||
|
||||
|
||||
class TokenExpiredError(TokenError):
|
||||
"""令牌过期"""
|
||||
|
||||
def __init__(self, message: str = "令牌已过期"):
|
||||
super().__init__(message)
|
||||
self.code = "TOKEN_EXPIRED"
|
||||
|
||||
|
||||
class AuthorizationError(AppException):
|
||||
"""授权错误"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "权限不足",
|
||||
code: str = "AUTHORIZATION_ERROR",
|
||||
details: dict[str, Any] | None = None,
|
||||
):
|
||||
super().__init__(message, code, details)
|
||||
|
||||
|
||||
class ValidationError(AppException):
|
||||
"""验证错误"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "数据验证失败",
|
||||
code: str = "VALIDATION_ERROR",
|
||||
details: dict[str, Any] | None = None,
|
||||
):
|
||||
super().__init__(message, code, details)
|
||||
|
||||
|
||||
class ResourceNotFoundError(AppException):
|
||||
"""资源未找到"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "资源不存在",
|
||||
resource_type: str = "resource",
|
||||
resource_id: Any = None,
|
||||
):
|
||||
super().__init__(
|
||||
message,
|
||||
"RESOURCE_NOT_FOUND",
|
||||
{"resource_type": resource_type, "resource_id": resource_id},
|
||||
)
|
||||
|
||||
|
||||
class ResourceConflictError(AppException):
|
||||
"""资源冲突(如重复创建)"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "资源已存在",
|
||||
code: str = "RESOURCE_CONFLICT",
|
||||
details: dict[str, Any] | None = None,
|
||||
):
|
||||
super().__init__(message, code, details)
|
||||
|
||||
|
||||
class UserNotFoundError(ResourceNotFoundError):
|
||||
"""用户不存在"""
|
||||
|
||||
def __init__(self, user_id: Any = None):
|
||||
super().__init__("用户不存在", "user", user_id)
|
||||
|
||||
|
||||
class UserAlreadyExistsError(ResourceConflictError):
|
||||
"""用户已存在"""
|
||||
|
||||
def __init__(self, field: str = "username"):
|
||||
super().__init__(
|
||||
f"该{field}已被注册",
|
||||
"USER_ALREADY_EXISTS",
|
||||
{"field": field},
|
||||
)
|
||||
|
||||
|
||||
class UserDisabledError(AuthenticationError):
|
||||
"""用户被禁用"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("账户已被禁用", "USER_DISABLED")
|
||||
|
||||
|
||||
class PasswordValidationError(ValidationError):
|
||||
"""密码验证错误"""
|
||||
|
||||
def __init__(self, message: str = "密码不符合要求"):
|
||||
super().__init__(message, "PASSWORD_VALIDATION_ERROR")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 余额相关异常
|
||||
# ============================================================
|
||||
|
||||
class InsufficientBalanceError(AppException):
|
||||
"""余额不足"""
|
||||
|
||||
def __init__(self, required: int, available: int):
|
||||
super().__init__(
|
||||
f"余额不足,需要 {required / 1000:.2f},当前可用 {available / 1000:.2f}",
|
||||
"INSUFFICIENT_BALANCE",
|
||||
{"required_units": required, "available_units": available},
|
||||
)
|
||||
|
||||
|
||||
class DuplicateTransactionError(AppException):
|
||||
"""重复交易"""
|
||||
|
||||
def __init__(self, idempotency_key: str):
|
||||
super().__init__(
|
||||
"该交易已处理",
|
||||
"DUPLICATE_TRANSACTION",
|
||||
{"idempotency_key": idempotency_key},
|
||||
)
|
||||
|
||||
|
||||
class ConcurrencyError(AppException):
|
||||
"""并发冲突"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
"操作冲突,请重试",
|
||||
"CONCURRENCY_ERROR",
|
||||
)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 兑换码相关异常
|
||||
# ============================================================
|
||||
|
||||
class RedeemCodeNotFoundError(AppException):
|
||||
"""兑换码不存在"""
|
||||
|
||||
def __init__(self, code: str):
|
||||
super().__init__(
|
||||
"兑换码不存在",
|
||||
"REDEEM_CODE_NOT_FOUND",
|
||||
{"code": code},
|
||||
)
|
||||
|
||||
|
||||
class RedeemCodeInvalidError(AppException):
|
||||
"""兑换码无效"""
|
||||
|
||||
def __init__(self, code: str, reason: str):
|
||||
super().__init__(
|
||||
f"兑换码无效: {reason}",
|
||||
"REDEEM_CODE_INVALID",
|
||||
{"code": code, "reason": reason},
|
||||
)
|
||||
|
||||
|
||||
class RedeemCodeExpiredError(AppException):
|
||||
"""兑换码已过期"""
|
||||
|
||||
def __init__(self, code: str):
|
||||
super().__init__(
|
||||
"兑换码已过期",
|
||||
"REDEEM_CODE_EXPIRED",
|
||||
{"code": code},
|
||||
)
|
||||
|
||||
|
||||
class RedeemCodeUsedError(AppException):
|
||||
"""兑换码已使用"""
|
||||
|
||||
def __init__(self, code: str):
|
||||
super().__init__(
|
||||
"兑换码已使用",
|
||||
"REDEEM_CODE_USED",
|
||||
{"code": code},
|
||||
)
|
||||
155
app/core/security.py
Normal file
155
app/core/security.py
Normal file
@@ -0,0 +1,155 @@
|
||||
"""
|
||||
安全相关功能
|
||||
|
||||
包括密码哈希、JWT 令牌生成与验证。
|
||||
使用 Argon2 作为密码哈希算法(目前最安全的选择)。
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
import jwt
|
||||
from argon2 import PasswordHasher
|
||||
from argon2.exceptions import InvalidHashError, VerifyMismatchError
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
# Argon2 密码哈希器,使用推荐的安全参数
|
||||
_password_hasher = PasswordHasher(
|
||||
time_cost=3, # 迭代次数
|
||||
memory_cost=65536, # 内存使用 (64MB)
|
||||
parallelism=4, # 并行度
|
||||
)
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
"""
|
||||
对密码进行哈希处理
|
||||
|
||||
Args:
|
||||
password: 明文密码
|
||||
|
||||
Returns:
|
||||
Argon2 哈希字符串
|
||||
"""
|
||||
return _password_hasher.hash(password)
|
||||
|
||||
|
||||
def verify_password(password: str, hashed_password: str) -> bool:
|
||||
"""
|
||||
验证密码是否匹配
|
||||
|
||||
Args:
|
||||
password: 明文密码
|
||||
hashed_password: 已哈希的密码
|
||||
|
||||
Returns:
|
||||
密码是否正确
|
||||
"""
|
||||
try:
|
||||
_password_hasher.verify(hashed_password, password)
|
||||
return True
|
||||
except (VerifyMismatchError, InvalidHashError):
|
||||
return False
|
||||
|
||||
|
||||
def password_needs_rehash(hashed_password: str) -> bool:
|
||||
"""
|
||||
检查密码是否需要重新哈希(参数升级时使用)
|
||||
|
||||
Args:
|
||||
hashed_password: 已哈希的密码
|
||||
|
||||
Returns:
|
||||
是否需要重新哈希
|
||||
"""
|
||||
return _password_hasher.check_needs_rehash(hashed_password)
|
||||
|
||||
|
||||
def create_access_token(
|
||||
subject: str | int,
|
||||
expires_delta: timedelta | None = None,
|
||||
extra_claims: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
创建访问令牌
|
||||
|
||||
Args:
|
||||
subject: 令牌主体(通常是用户ID)
|
||||
expires_delta: 过期时间增量
|
||||
extra_claims: 额外的声明数据
|
||||
|
||||
Returns:
|
||||
JWT 访问令牌
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
if expires_delta:
|
||||
expire = now + expires_delta
|
||||
else:
|
||||
expire = now + timedelta(minutes=settings.access_token_expire_minutes)
|
||||
|
||||
payload = {
|
||||
"sub": str(subject),
|
||||
"iat": now,
|
||||
"exp": expire,
|
||||
"type": "access",
|
||||
}
|
||||
|
||||
if extra_claims:
|
||||
payload.update(extra_claims)
|
||||
|
||||
return jwt.encode(payload, settings.secret_key, algorithm=settings.algorithm)
|
||||
|
||||
|
||||
def create_refresh_token(
|
||||
subject: str | int,
|
||||
expires_delta: timedelta | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
创建刷新令牌
|
||||
|
||||
Args:
|
||||
subject: 令牌主体(通常是用户ID)
|
||||
expires_delta: 过期时间增量
|
||||
|
||||
Returns:
|
||||
JWT 刷新令牌
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
if expires_delta:
|
||||
expire = now + expires_delta
|
||||
else:
|
||||
expire = now + timedelta(days=settings.refresh_token_expire_days)
|
||||
|
||||
payload = {
|
||||
"sub": str(subject),
|
||||
"iat": now,
|
||||
"exp": expire,
|
||||
"type": "refresh",
|
||||
}
|
||||
|
||||
return jwt.encode(payload, settings.secret_key, algorithm=settings.algorithm)
|
||||
|
||||
|
||||
def decode_token(token: str) -> dict[str, Any]:
|
||||
"""
|
||||
解码并验证 JWT 令牌
|
||||
|
||||
Args:
|
||||
token: JWT 令牌字符串
|
||||
|
||||
Returns:
|
||||
解码后的负载数据
|
||||
|
||||
Raises:
|
||||
jwt.InvalidTokenError: 令牌无效
|
||||
jwt.ExpiredSignatureError: 令牌已过期
|
||||
"""
|
||||
return jwt.decode(
|
||||
token,
|
||||
settings.secret_key,
|
||||
algorithms=[settings.algorithm],
|
||||
)
|
||||
|
||||
100
app/database.py
Normal file
100
app/database.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""
|
||||
数据库连接与会话管理
|
||||
|
||||
使用 SQLAlchemy 2.0 异步模式。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
"""SQLAlchemy 声明式基类"""
|
||||
pass
|
||||
|
||||
|
||||
def _ensure_sqlite_dir() -> None:
|
||||
"""
|
||||
确保 SQLite 数据库目录存在
|
||||
|
||||
如果配置使用 SQLite,自动创建数据库文件所在的目录。
|
||||
"""
|
||||
if settings.database_type != "sqlite":
|
||||
return
|
||||
|
||||
db_path = Path(settings.database_sqlite_path)
|
||||
db_dir = db_path.parent
|
||||
|
||||
if not db_dir.exists():
|
||||
logger.info(f"创建数据库目录: {db_dir.absolute()}")
|
||||
db_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
# 确保 SQLite 目录存在
|
||||
_ensure_sqlite_dir()
|
||||
|
||||
# 创建异步引擎
|
||||
engine = create_async_engine(
|
||||
settings.database_url,
|
||||
echo=settings.database_echo,
|
||||
future=True,
|
||||
)
|
||||
|
||||
# 创建异步会话工厂
|
||||
async_session_factory = async_sessionmaker(
|
||||
engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
)
|
||||
|
||||
|
||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""
|
||||
获取数据库会话(依赖注入用)
|
||||
|
||||
Yields:
|
||||
异步数据库会话
|
||||
"""
|
||||
async with async_session_factory() as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
async def init_db() -> None:
|
||||
"""初始化数据库(创建所有表)"""
|
||||
# 导入所有模型以确保它们被注册
|
||||
from app.models import ( # noqa: F401
|
||||
User,
|
||||
UserBalance,
|
||||
BalanceTransaction,
|
||||
RedeemCode,
|
||||
RedeemCodeBatch,
|
||||
RedeemCodeUsageLog,
|
||||
)
|
||||
|
||||
logger.info(f"初始化数据库: {settings.database_url}")
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
logger.info("数据库初始化完成")
|
||||
|
||||
|
||||
async def close_db() -> None:
|
||||
"""关闭数据库连接"""
|
||||
await engine.dispose()
|
||||
|
||||
268
app/main.py
Normal file
268
app/main.py
Normal file
@@ -0,0 +1,268 @@
|
||||
"""
|
||||
FastAPI 应用入口
|
||||
|
||||
配置应用实例、中间件、异常处理器等。
|
||||
"""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from fastapi import FastAPI, Request, status
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from app.api.v1.router import api_router
|
||||
from app.core.config import settings
|
||||
from app.core.exceptions import (
|
||||
AppException,
|
||||
AuthenticationError,
|
||||
AuthorizationError,
|
||||
ResourceConflictError,
|
||||
ResourceNotFoundError,
|
||||
ValidationError,
|
||||
)
|
||||
from app.database import close_db, init_db
|
||||
|
||||
|
||||
def get_frontend_dir() -> Path:
|
||||
"""
|
||||
获取前端静态文件目录
|
||||
|
||||
支持相对路径和绝对路径配置
|
||||
"""
|
||||
path = Path(settings.frontend_static_path)
|
||||
if path.is_absolute():
|
||||
return path
|
||||
# 相对路径基于项目根目录
|
||||
return Path(__file__).parent.parent / path
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
"""
|
||||
应用生命周期管理
|
||||
|
||||
启动时初始化数据库,关闭时释放资源。
|
||||
"""
|
||||
# 启动时
|
||||
await init_db()
|
||||
yield
|
||||
# 关闭时
|
||||
await close_db()
|
||||
|
||||
|
||||
def create_application() -> FastAPI:
|
||||
"""
|
||||
创建 FastAPI 应用实例
|
||||
|
||||
Returns:
|
||||
配置完成的 FastAPI 应用
|
||||
"""
|
||||
app = FastAPI(
|
||||
title=settings.app_name,
|
||||
version=settings.app_version,
|
||||
description="现代化用户认证系统 API",
|
||||
docs_url="/docs" if not settings.is_production else None,
|
||||
redoc_url="/redoc" if not settings.is_production else None,
|
||||
openapi_url="/openapi.json" if not settings.is_production else None,
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# 配置 CORS
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"] if settings.debug else [],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 注册异常处理器
|
||||
register_exception_handlers(app)
|
||||
|
||||
# 注册路由
|
||||
app.include_router(api_router, prefix=settings.api_v1_prefix)
|
||||
|
||||
# 健康检查端点
|
||||
@app.get("/health", tags=["健康检查"])
|
||||
async def health_check():
|
||||
"""健康检查接口"""
|
||||
return {"status": "healthy", "version": settings.app_version}
|
||||
|
||||
# 挂载前端静态文件(当静态文件目录存在时)
|
||||
frontend_dir = get_frontend_dir()
|
||||
if frontend_dir.exists():
|
||||
setup_frontend_routes(app, frontend_dir)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def setup_frontend_routes(app: FastAPI, frontend_dir: Path) -> None:
|
||||
"""
|
||||
配置前端静态文件路由
|
||||
|
||||
Args:
|
||||
app: FastAPI 应用实例
|
||||
frontend_dir: 前端静态文件目录
|
||||
"""
|
||||
# 挂载 _next 静态资源目录(Next.js 构建产物)
|
||||
next_static_dir = frontend_dir / "_next"
|
||||
if next_static_dir.exists():
|
||||
app.mount(
|
||||
"/_next",
|
||||
StaticFiles(directory=str(next_static_dir)),
|
||||
name="next_static",
|
||||
)
|
||||
|
||||
# 挂载 assets 目录(通用静态资源)
|
||||
assets_dir = frontend_dir / "assets"
|
||||
if assets_dir.exists():
|
||||
app.mount(
|
||||
"/assets",
|
||||
StaticFiles(directory=str(assets_dir)),
|
||||
name="assets",
|
||||
)
|
||||
|
||||
# 挂载 static 目录(通用静态资源)
|
||||
static_dir = frontend_dir / "static"
|
||||
if static_dir.exists():
|
||||
app.mount(
|
||||
"/static",
|
||||
StaticFiles(directory=str(static_dir)),
|
||||
name="static",
|
||||
)
|
||||
|
||||
# SPA 路由处理 - 必须在最后注册,以避免覆盖 API 路由
|
||||
@app.get("/{full_path:path}", include_in_schema=False)
|
||||
async def serve_spa(request: Request, full_path: str):
|
||||
"""
|
||||
SPA 路由处理
|
||||
|
||||
- 对于静态文件请求,直接返回文件
|
||||
- 对于 SPA 路由请求,返回 index.html
|
||||
"""
|
||||
# 跳过 API 路由和健康检查
|
||||
if full_path.startswith("api/") or full_path == "health":
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content={"success": False, "message": "Not Found"},
|
||||
)
|
||||
|
||||
# 尝试查找静态文件
|
||||
file_path = frontend_dir / full_path
|
||||
|
||||
# 如果是目录,尝试查找 index.html
|
||||
if file_path.is_dir():
|
||||
file_path = file_path / "index.html"
|
||||
|
||||
# 如果文件存在,返回文件
|
||||
if file_path.is_file():
|
||||
return FileResponse(
|
||||
file_path,
|
||||
headers=_get_cache_headers(full_path),
|
||||
)
|
||||
|
||||
# 尝试添加 .html 后缀(Next.js 静态导出格式)
|
||||
html_path = frontend_dir / f"{full_path}.html"
|
||||
if html_path.is_file():
|
||||
return FileResponse(html_path)
|
||||
|
||||
# 默认返回 index.html(SPA 路由回退)
|
||||
index_path = frontend_dir / "index.html"
|
||||
if index_path.is_file():
|
||||
return FileResponse(index_path)
|
||||
|
||||
# 前端文件不存在
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content={"success": False, "message": "Not Found"},
|
||||
)
|
||||
|
||||
|
||||
def _get_cache_headers(path: str) -> dict[str, str]:
|
||||
"""
|
||||
根据文件类型返回适当的缓存头
|
||||
|
||||
Args:
|
||||
path: 请求路径
|
||||
|
||||
Returns:
|
||||
缓存相关的 HTTP 头
|
||||
"""
|
||||
# 静态资源(带 hash 的文件)使用长期缓存
|
||||
if "/_next/" in path or path.startswith("_next/"):
|
||||
return {"Cache-Control": "public, max-age=31536000, immutable"}
|
||||
|
||||
# 其他静态文件使用短期缓存
|
||||
static_extensions = {".js", ".css", ".png", ".jpg", ".jpeg", ".gif", ".svg", ".ico", ".woff", ".woff2"}
|
||||
if any(path.endswith(ext) for ext in static_extensions):
|
||||
return {"Cache-Control": "public, max-age=86400"}
|
||||
|
||||
# HTML 文件不缓存
|
||||
return {"Cache-Control": "no-cache"}
|
||||
|
||||
|
||||
def register_exception_handlers(app: FastAPI) -> None:
|
||||
"""注册全局异常处理器"""
|
||||
|
||||
@app.exception_handler(AppException)
|
||||
async def app_exception_handler(
|
||||
request: Request,
|
||||
exc: AppException,
|
||||
) -> JSONResponse:
|
||||
"""应用异常处理"""
|
||||
status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
|
||||
if isinstance(exc, AuthenticationError):
|
||||
status_code = status.HTTP_401_UNAUTHORIZED
|
||||
elif isinstance(exc, AuthorizationError):
|
||||
status_code = status.HTTP_403_FORBIDDEN
|
||||
elif isinstance(exc, ResourceNotFoundError):
|
||||
status_code = status.HTTP_404_NOT_FOUND
|
||||
elif isinstance(exc, ResourceConflictError):
|
||||
status_code = status.HTTP_409_CONFLICT
|
||||
elif isinstance(exc, ValidationError):
|
||||
status_code = status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status_code,
|
||||
content={
|
||||
"success": False,
|
||||
"message": exc.message,
|
||||
"code": exc.code,
|
||||
"details": exc.details,
|
||||
},
|
||||
)
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(
|
||||
request: Request,
|
||||
exc: RequestValidationError,
|
||||
) -> JSONResponse:
|
||||
"""请求验证异常处理"""
|
||||
errors = []
|
||||
for error in exc.errors():
|
||||
field = ".".join(str(loc) for loc in error["loc"])
|
||||
errors.append({
|
||||
"field": field,
|
||||
"message": error["msg"],
|
||||
"type": error["type"],
|
||||
})
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
content={
|
||||
"success": False,
|
||||
"message": "请求数据验证失败",
|
||||
"code": "VALIDATION_ERROR",
|
||||
"details": {"errors": errors},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# 创建应用实例
|
||||
app = create_application()
|
||||
|
||||
28
app/models/__init__.py
Normal file
28
app/models/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""数据库模型"""
|
||||
|
||||
from app.models.user import User
|
||||
from app.models.balance import (
|
||||
UserBalance,
|
||||
BalanceTransaction,
|
||||
TransactionType,
|
||||
TransactionStatus,
|
||||
)
|
||||
from app.models.redeem_code import (
|
||||
RedeemCode,
|
||||
RedeemCodeBatch,
|
||||
RedeemCodeUsageLog,
|
||||
RedeemCodeStatus,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"User",
|
||||
"UserBalance",
|
||||
"BalanceTransaction",
|
||||
"TransactionType",
|
||||
"TransactionStatus",
|
||||
"RedeemCode",
|
||||
"RedeemCodeBatch",
|
||||
"RedeemCodeUsageLog",
|
||||
"RedeemCodeStatus",
|
||||
]
|
||||
|
||||
320
app/models/balance.py
Normal file
320
app/models/balance.py
Normal file
@@ -0,0 +1,320 @@
|
||||
"""
|
||||
余额与交易模型
|
||||
|
||||
定义用户余额、余额交易记录相关的数据表结构。
|
||||
|
||||
设计说明:
|
||||
- 余额内部以无符号整数存储(单位额度),避免浮点精度问题
|
||||
- 1.00 显示余额 = 1000 单位额度(精度 0.001)
|
||||
- 所有金额操作都在单位额度层面进行,只在展示时转换
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import (
|
||||
BigInteger,
|
||||
DateTime,
|
||||
Enum as SQLEnum,
|
||||
ForeignKey,
|
||||
Index,
|
||||
String,
|
||||
Text,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
def generate_uuid() -> str:
|
||||
"""生成 UUID 字符串"""
|
||||
return str(uuid4())
|
||||
|
||||
|
||||
def utc_now() -> datetime:
|
||||
"""获取当前 UTC 时间"""
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
class TransactionType(str, Enum):
|
||||
"""交易类型枚举"""
|
||||
|
||||
RECHARGE = "recharge" # 充值(兑换码等)
|
||||
DEDUCTION = "deduction" # 扣款(API 调用等)
|
||||
REFUND = "refund" # 退款
|
||||
ADJUSTMENT = "adjustment" # 管理员调整
|
||||
TRANSFER_IN = "transfer_in" # 转入
|
||||
TRANSFER_OUT = "transfer_out" # 转出
|
||||
|
||||
|
||||
class TransactionStatus(str, Enum):
|
||||
"""交易状态枚举"""
|
||||
|
||||
PENDING = "pending" # 待处理
|
||||
COMPLETED = "completed" # 已完成
|
||||
FAILED = "failed" # 失败
|
||||
CANCELLED = "cancelled" # 已取消
|
||||
|
||||
|
||||
class UserBalance(Base):
|
||||
"""
|
||||
用户余额模型
|
||||
|
||||
独立的余额表,便于扩展(如多币种、账户类型)和锁管理
|
||||
"""
|
||||
|
||||
__tablename__ = "user_balances"
|
||||
|
||||
# 主键
|
||||
id: Mapped[str] = mapped_column(
|
||||
String(36),
|
||||
primary_key=True,
|
||||
default=generate_uuid,
|
||||
comment="余额记录唯一标识",
|
||||
)
|
||||
|
||||
# 关联用户(一对一)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
String(36),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
unique=True,
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="关联用户 ID",
|
||||
)
|
||||
|
||||
# 余额信息(内部以整数单位存储)
|
||||
balance: Mapped[int] = mapped_column(
|
||||
BigInteger,
|
||||
default=0,
|
||||
nullable=False,
|
||||
comment="当前余额(单位额度,1000 = 1.00 显示余额)",
|
||||
)
|
||||
|
||||
# 冻结金额(用于处理中的交易)
|
||||
frozen_balance: Mapped[int] = mapped_column(
|
||||
BigInteger,
|
||||
default=0,
|
||||
nullable=False,
|
||||
comment="冻结余额(单位额度)",
|
||||
)
|
||||
|
||||
# 累计统计
|
||||
total_recharged: Mapped[int] = mapped_column(
|
||||
BigInteger,
|
||||
default=0,
|
||||
nullable=False,
|
||||
comment="累计充值(单位额度)",
|
||||
)
|
||||
total_consumed: Mapped[int] = mapped_column(
|
||||
BigInteger,
|
||||
default=0,
|
||||
nullable=False,
|
||||
comment="累计消费(单位额度)",
|
||||
)
|
||||
|
||||
# 乐观锁版本号
|
||||
version: Mapped[int] = mapped_column(
|
||||
BigInteger,
|
||||
default=0,
|
||||
nullable=False,
|
||||
comment="版本号(乐观锁)",
|
||||
)
|
||||
|
||||
# 时间戳
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
default=utc_now,
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
comment="创建时间",
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
default=utc_now,
|
||||
onupdate=utc_now,
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
comment="更新时间",
|
||||
)
|
||||
|
||||
# 关系
|
||||
user: Mapped["User"] = relationship(
|
||||
"User",
|
||||
back_populates="balance_account",
|
||||
lazy="selectin",
|
||||
)
|
||||
transactions: Mapped[list["BalanceTransaction"]] = relationship(
|
||||
"BalanceTransaction",
|
||||
back_populates="balance_account",
|
||||
lazy="selectin",
|
||||
order_by="desc(BalanceTransaction.created_at)",
|
||||
)
|
||||
|
||||
@property
|
||||
def available_balance(self) -> int:
|
||||
"""可用余额(总余额 - 冻结余额)"""
|
||||
return self.balance - self.frozen_balance
|
||||
|
||||
@property
|
||||
def display_balance(self) -> str:
|
||||
"""显示余额(2 位小数)"""
|
||||
return f"{self.balance / 1000:.2f}"
|
||||
|
||||
@property
|
||||
def display_available_balance(self) -> str:
|
||||
"""显示可用余额(2 位小数)"""
|
||||
return f"{self.available_balance / 1000:.2f}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<UserBalance(user_id={self.user_id!r}, balance={self.display_balance})>"
|
||||
|
||||
|
||||
class BalanceTransaction(Base):
|
||||
"""
|
||||
余额交易记录模型
|
||||
|
||||
记录所有余额变动,用于审计和对账
|
||||
"""
|
||||
|
||||
__tablename__ = "balance_transactions"
|
||||
__table_args__ = (
|
||||
Index("ix_balance_transactions_user_created", "user_id", "created_at"),
|
||||
Index("ix_balance_transactions_type_status", "transaction_type", "status"),
|
||||
)
|
||||
|
||||
# 主键
|
||||
id: Mapped[str] = mapped_column(
|
||||
String(36),
|
||||
primary_key=True,
|
||||
default=generate_uuid,
|
||||
comment="交易记录唯一标识",
|
||||
)
|
||||
|
||||
# 关联
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
String(36),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="关联用户 ID",
|
||||
)
|
||||
balance_account_id: Mapped[str] = mapped_column(
|
||||
String(36),
|
||||
ForeignKey("user_balances.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="关联余额账户 ID",
|
||||
)
|
||||
|
||||
# 交易信息
|
||||
transaction_type: Mapped[TransactionType] = mapped_column(
|
||||
SQLEnum(TransactionType),
|
||||
nullable=False,
|
||||
comment="交易类型",
|
||||
)
|
||||
status: Mapped[TransactionStatus] = mapped_column(
|
||||
SQLEnum(TransactionStatus),
|
||||
default=TransactionStatus.COMPLETED,
|
||||
nullable=False,
|
||||
comment="交易状态",
|
||||
)
|
||||
|
||||
# 金额(整数单位)
|
||||
amount: Mapped[int] = mapped_column(
|
||||
BigInteger,
|
||||
nullable=False,
|
||||
comment="交易金额(单位额度,正数表示收入,负数表示支出)",
|
||||
)
|
||||
balance_before: Mapped[int] = mapped_column(
|
||||
BigInteger,
|
||||
nullable=False,
|
||||
comment="交易前余额(单位额度)",
|
||||
)
|
||||
balance_after: Mapped[int] = mapped_column(
|
||||
BigInteger,
|
||||
nullable=False,
|
||||
comment="交易后余额(单位额度)",
|
||||
)
|
||||
|
||||
# 业务关联
|
||||
reference_type: Mapped[str | None] = mapped_column(
|
||||
String(64),
|
||||
nullable=True,
|
||||
index=True,
|
||||
comment="关联业务类型(如 redeem_code、api_call)",
|
||||
)
|
||||
reference_id: Mapped[str | None] = mapped_column(
|
||||
String(64),
|
||||
nullable=True,
|
||||
index=True,
|
||||
comment="关联业务 ID",
|
||||
)
|
||||
|
||||
# 描述
|
||||
description: Mapped[str | None] = mapped_column(
|
||||
String(255),
|
||||
nullable=True,
|
||||
comment="交易描述",
|
||||
)
|
||||
remark: Mapped[str | None] = mapped_column(
|
||||
Text,
|
||||
nullable=True,
|
||||
comment="备注(内部使用)",
|
||||
)
|
||||
|
||||
# 操作人(管理员调整时记录)
|
||||
operator_id: Mapped[str | None] = mapped_column(
|
||||
String(36),
|
||||
ForeignKey("users.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
comment="操作人 ID(管理员调整时)",
|
||||
)
|
||||
|
||||
# 幂等键(防止重复提交)
|
||||
idempotency_key: Mapped[str | None] = mapped_column(
|
||||
String(64),
|
||||
unique=True,
|
||||
nullable=True,
|
||||
comment="幂等键",
|
||||
)
|
||||
|
||||
# 时间戳
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
default=utc_now,
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
comment="创建时间",
|
||||
)
|
||||
|
||||
# 关系
|
||||
user: Mapped["User"] = relationship(
|
||||
"User",
|
||||
foreign_keys=[user_id],
|
||||
lazy="selectin",
|
||||
)
|
||||
balance_account: Mapped["UserBalance"] = relationship(
|
||||
"UserBalance",
|
||||
back_populates="transactions",
|
||||
lazy="selectin",
|
||||
)
|
||||
|
||||
@property
|
||||
def display_amount(self) -> str:
|
||||
"""显示金额(带符号,2 位小数)"""
|
||||
return f"{self.amount / 1000:+.2f}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<BalanceTransaction(id={self.id!r}, "
|
||||
f"type={self.transaction_type.value}, "
|
||||
f"amount={self.display_amount})>"
|
||||
)
|
||||
|
||||
409
app/models/redeem_code.py
Normal file
409
app/models/redeem_code.py
Normal file
@@ -0,0 +1,409 @@
|
||||
"""
|
||||
兑换码模型
|
||||
|
||||
定义余额兑换码相关的数据表结构。
|
||||
|
||||
设计说明:
|
||||
- 兑换码支持批量生成、导入导出
|
||||
- 记录完整的使用日志
|
||||
- 支持设置有效期和使用限制
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import uuid4
|
||||
import secrets
|
||||
import string
|
||||
|
||||
from sqlalchemy import (
|
||||
BigInteger,
|
||||
Boolean,
|
||||
DateTime,
|
||||
Enum as SQLEnum,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
def generate_uuid() -> str:
|
||||
"""生成 UUID 字符串"""
|
||||
return str(uuid4())
|
||||
|
||||
|
||||
def utc_now() -> datetime:
|
||||
"""获取当前 UTC 时间"""
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def generate_redeem_code(length: int = 16) -> str:
|
||||
"""
|
||||
生成兑换码
|
||||
|
||||
格式:大写字母和数字,分段显示(如 XXXX-XXXX-XXXX-XXXX)
|
||||
排除容易混淆的字符:0/O, 1/I/L
|
||||
"""
|
||||
# 可用字符集(排除易混淆字符)
|
||||
alphabet = "ABCDEFGHJKMNPQRSTUVWXYZ23456789"
|
||||
code = "".join(secrets.choice(alphabet) for _ in range(length))
|
||||
# 每 4 个字符用连字符分隔
|
||||
return "-".join(code[i:i + 4] for i in range(0, len(code), 4))
|
||||
|
||||
|
||||
class RedeemCodeStatus(str, Enum):
|
||||
"""兑换码状态枚举"""
|
||||
|
||||
ACTIVE = "active" # 可用
|
||||
USED = "used" # 已使用
|
||||
DISABLED = "disabled" # 已禁用
|
||||
EXPIRED = "expired" # 已过期
|
||||
|
||||
|
||||
class RedeemCodeBatch(Base):
|
||||
"""
|
||||
兑换码批次模型
|
||||
|
||||
用于管理批量生成的兑换码
|
||||
"""
|
||||
|
||||
__tablename__ = "redeem_code_batches"
|
||||
|
||||
# 主键
|
||||
id: Mapped[str] = mapped_column(
|
||||
String(36),
|
||||
primary_key=True,
|
||||
default=generate_uuid,
|
||||
comment="批次唯一标识",
|
||||
)
|
||||
|
||||
# 批次信息
|
||||
name: Mapped[str] = mapped_column(
|
||||
String(128),
|
||||
nullable=False,
|
||||
comment="批次名称",
|
||||
)
|
||||
description: Mapped[str | None] = mapped_column(
|
||||
Text,
|
||||
nullable=True,
|
||||
comment="批次描述",
|
||||
)
|
||||
|
||||
# 面值(单位额度)
|
||||
face_value: Mapped[int] = mapped_column(
|
||||
BigInteger,
|
||||
nullable=False,
|
||||
comment="面值(单位额度,1000 = 1.00)",
|
||||
)
|
||||
|
||||
# 数量统计
|
||||
total_count: Mapped[int] = mapped_column(
|
||||
Integer,
|
||||
default=0,
|
||||
nullable=False,
|
||||
comment="生成总数",
|
||||
)
|
||||
used_count: Mapped[int] = mapped_column(
|
||||
Integer,
|
||||
default=0,
|
||||
nullable=False,
|
||||
comment="已使用数量",
|
||||
)
|
||||
|
||||
# 创建者
|
||||
created_by: Mapped[str] = mapped_column(
|
||||
String(36),
|
||||
ForeignKey("users.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
comment="创建者 ID",
|
||||
)
|
||||
|
||||
# 时间戳
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
default=utc_now,
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
comment="创建时间",
|
||||
)
|
||||
|
||||
# 关系
|
||||
codes: Mapped[list["RedeemCode"]] = relationship(
|
||||
"RedeemCode",
|
||||
back_populates="batch",
|
||||
lazy="selectin",
|
||||
)
|
||||
creator: Mapped["User"] = relationship(
|
||||
"User",
|
||||
lazy="selectin",
|
||||
)
|
||||
|
||||
@property
|
||||
def display_face_value(self) -> str:
|
||||
"""显示面值(2 位小数)"""
|
||||
return f"{self.face_value / 1000:.2f}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<RedeemCodeBatch(id={self.id!r}, name={self.name!r})>"
|
||||
|
||||
|
||||
class RedeemCode(Base):
|
||||
"""
|
||||
兑换码模型
|
||||
|
||||
单个兑换码记录
|
||||
"""
|
||||
|
||||
__tablename__ = "redeem_codes"
|
||||
__table_args__ = (
|
||||
Index("ix_redeem_codes_status_expires", "status", "expires_at"),
|
||||
)
|
||||
|
||||
# 主键
|
||||
id: Mapped[str] = mapped_column(
|
||||
String(36),
|
||||
primary_key=True,
|
||||
default=generate_uuid,
|
||||
comment="兑换码记录唯一标识",
|
||||
)
|
||||
|
||||
# 兑换码(唯一索引)
|
||||
code: Mapped[str] = mapped_column(
|
||||
String(32),
|
||||
unique=True,
|
||||
nullable=False,
|
||||
index=True,
|
||||
default=generate_redeem_code,
|
||||
comment="兑换码",
|
||||
)
|
||||
|
||||
# 批次关联(可选)
|
||||
batch_id: Mapped[str | None] = mapped_column(
|
||||
String(36),
|
||||
ForeignKey("redeem_code_batches.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
comment="关联批次 ID",
|
||||
)
|
||||
|
||||
# 面值(单位额度)
|
||||
face_value: Mapped[int] = mapped_column(
|
||||
BigInteger,
|
||||
nullable=False,
|
||||
comment="面值(单位额度,1000 = 1.00)",
|
||||
)
|
||||
|
||||
# 状态
|
||||
status: Mapped[RedeemCodeStatus] = mapped_column(
|
||||
SQLEnum(RedeemCodeStatus),
|
||||
default=RedeemCodeStatus.ACTIVE,
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="兑换码状态",
|
||||
)
|
||||
|
||||
# 使用限制
|
||||
max_uses: Mapped[int] = mapped_column(
|
||||
Integer,
|
||||
default=1,
|
||||
nullable=False,
|
||||
comment="最大使用次数",
|
||||
)
|
||||
used_count: Mapped[int] = mapped_column(
|
||||
Integer,
|
||||
default=0,
|
||||
nullable=False,
|
||||
comment="已使用次数",
|
||||
)
|
||||
|
||||
# 有效期
|
||||
expires_at: Mapped[datetime | None] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=True,
|
||||
comment="过期时间",
|
||||
)
|
||||
|
||||
# 使用信息
|
||||
used_by: Mapped[str | None] = mapped_column(
|
||||
String(36),
|
||||
ForeignKey("users.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
comment="使用者 ID(最后使用)",
|
||||
)
|
||||
used_at: Mapped[datetime | None] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=True,
|
||||
comment="使用时间(最后使用)",
|
||||
)
|
||||
|
||||
# 备注
|
||||
remark: Mapped[str | None] = mapped_column(
|
||||
Text,
|
||||
nullable=True,
|
||||
comment="备注",
|
||||
)
|
||||
|
||||
# 创建者
|
||||
created_by: Mapped[str | None] = mapped_column(
|
||||
String(36),
|
||||
ForeignKey("users.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
comment="创建者 ID",
|
||||
)
|
||||
|
||||
# 时间戳
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
default=utc_now,
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
comment="创建时间",
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
default=utc_now,
|
||||
onupdate=utc_now,
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
comment="更新时间",
|
||||
)
|
||||
|
||||
# 关系
|
||||
batch: Mapped["RedeemCodeBatch | None"] = relationship(
|
||||
"RedeemCodeBatch",
|
||||
back_populates="codes",
|
||||
lazy="selectin",
|
||||
)
|
||||
usage_logs: Mapped[list["RedeemCodeUsageLog"]] = relationship(
|
||||
"RedeemCodeUsageLog",
|
||||
back_populates="redeem_code",
|
||||
lazy="selectin",
|
||||
order_by="desc(RedeemCodeUsageLog.created_at)",
|
||||
)
|
||||
|
||||
@property
|
||||
def display_face_value(self) -> str:
|
||||
"""显示面值(2 位小数)"""
|
||||
return f"{self.face_value / 1000:.2f}"
|
||||
|
||||
@property
|
||||
def is_valid(self) -> bool:
|
||||
"""检查兑换码是否有效"""
|
||||
if self.status != RedeemCodeStatus.ACTIVE:
|
||||
return False
|
||||
if self.used_count >= self.max_uses:
|
||||
return False
|
||||
if self.expires_at and self.expires_at < utc_now():
|
||||
return False
|
||||
return True
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<RedeemCode(code={self.code!r}, status={self.status.value})>"
|
||||
|
||||
|
||||
class RedeemCodeUsageLog(Base):
|
||||
"""
|
||||
兑换码使用日志模型
|
||||
|
||||
记录每次兑换的详细信息
|
||||
"""
|
||||
|
||||
__tablename__ = "redeem_code_usage_logs"
|
||||
__table_args__ = (
|
||||
Index("ix_redeem_usage_user_created", "user_id", "created_at"),
|
||||
)
|
||||
|
||||
# 主键
|
||||
id: Mapped[str] = mapped_column(
|
||||
String(36),
|
||||
primary_key=True,
|
||||
default=generate_uuid,
|
||||
comment="日志唯一标识",
|
||||
)
|
||||
|
||||
# 关联
|
||||
redeem_code_id: Mapped[str] = mapped_column(
|
||||
String(36),
|
||||
ForeignKey("redeem_codes.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="关联兑换码 ID",
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
String(36),
|
||||
ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
comment="使用者 ID",
|
||||
)
|
||||
transaction_id: Mapped[str | None] = mapped_column(
|
||||
String(36),
|
||||
ForeignKey("balance_transactions.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
comment="关联交易记录 ID",
|
||||
)
|
||||
|
||||
# 兑换信息快照
|
||||
code_snapshot: Mapped[str] = mapped_column(
|
||||
String(32),
|
||||
nullable=False,
|
||||
comment="兑换码快照",
|
||||
)
|
||||
face_value: Mapped[int] = mapped_column(
|
||||
BigInteger,
|
||||
nullable=False,
|
||||
comment="面值(单位额度)",
|
||||
)
|
||||
|
||||
# 客户端信息
|
||||
ip_address: Mapped[str | None] = mapped_column(
|
||||
String(64),
|
||||
nullable=True,
|
||||
comment="客户端 IP",
|
||||
)
|
||||
user_agent: Mapped[str | None] = mapped_column(
|
||||
String(512),
|
||||
nullable=True,
|
||||
comment="User Agent",
|
||||
)
|
||||
|
||||
# 时间戳
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
default=utc_now,
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
comment="使用时间",
|
||||
)
|
||||
|
||||
# 关系
|
||||
redeem_code: Mapped["RedeemCode"] = relationship(
|
||||
"RedeemCode",
|
||||
back_populates="usage_logs",
|
||||
lazy="selectin",
|
||||
)
|
||||
user: Mapped["User"] = relationship(
|
||||
"User",
|
||||
lazy="selectin",
|
||||
)
|
||||
|
||||
@property
|
||||
def display_face_value(self) -> str:
|
||||
"""显示面值(2 位小数)"""
|
||||
return f"{self.face_value / 1000:.2f}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<RedeemCodeUsageLog(code={self.code_snapshot!r}, "
|
||||
f"user_id={self.user_id!r})>"
|
||||
)
|
||||
|
||||
141
app/models/user.py
Normal file
141
app/models/user.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
用户模型
|
||||
|
||||
定义用户数据表结构。
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, String, Text, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.models.balance import UserBalance
|
||||
|
||||
|
||||
def generate_uuid() -> str:
|
||||
"""生成 UUID 字符串"""
|
||||
return str(uuid4())
|
||||
|
||||
|
||||
def utc_now() -> datetime:
|
||||
"""获取当前 UTC 时间"""
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
class User(Base):
|
||||
"""用户模型"""
|
||||
|
||||
__tablename__ = "users"
|
||||
|
||||
# 主键:使用 UUID 字符串
|
||||
id: Mapped[str] = mapped_column(
|
||||
String(36),
|
||||
primary_key=True,
|
||||
default=generate_uuid,
|
||||
comment="用户唯一标识",
|
||||
)
|
||||
|
||||
# 账户信息
|
||||
username: Mapped[str] = mapped_column(
|
||||
String(32),
|
||||
unique=True,
|
||||
index=True,
|
||||
nullable=False,
|
||||
comment="用户名",
|
||||
)
|
||||
email: Mapped[str | None] = mapped_column(
|
||||
String(255),
|
||||
unique=True,
|
||||
index=True,
|
||||
nullable=True,
|
||||
comment="邮箱地址",
|
||||
)
|
||||
hashed_password: Mapped[str | None] = mapped_column(
|
||||
String(255),
|
||||
nullable=True, # OAuth2 用户可能没有密码
|
||||
comment="密码哈希",
|
||||
)
|
||||
|
||||
# OAuth2 关联信息
|
||||
oauth_provider: Mapped[str | None] = mapped_column(
|
||||
String(32),
|
||||
nullable=True,
|
||||
index=True,
|
||||
comment="OAuth2 提供商(如 linuxdo)",
|
||||
)
|
||||
oauth_user_id: Mapped[str | None] = mapped_column(
|
||||
String(128),
|
||||
nullable=True,
|
||||
index=True,
|
||||
comment="OAuth2 用户 ID",
|
||||
)
|
||||
|
||||
# 用户状态
|
||||
is_active: Mapped[bool] = mapped_column(
|
||||
Boolean,
|
||||
default=True,
|
||||
nullable=False,
|
||||
comment="是否激活",
|
||||
)
|
||||
is_superuser: Mapped[bool] = mapped_column(
|
||||
Boolean,
|
||||
default=False,
|
||||
nullable=False,
|
||||
comment="是否为超级管理员",
|
||||
)
|
||||
|
||||
# 个人信息
|
||||
nickname: Mapped[str | None] = mapped_column(
|
||||
String(64),
|
||||
nullable=True,
|
||||
comment="昵称",
|
||||
)
|
||||
avatar_url: Mapped[str | None] = mapped_column(
|
||||
String(512),
|
||||
nullable=True,
|
||||
comment="头像 URL",
|
||||
)
|
||||
bio: Mapped[str | None] = mapped_column(
|
||||
Text,
|
||||
nullable=True,
|
||||
comment="个人简介",
|
||||
)
|
||||
|
||||
# 时间戳
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
default=utc_now,
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
comment="创建时间",
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
default=utc_now,
|
||||
onupdate=utc_now,
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
comment="更新时间",
|
||||
)
|
||||
last_login_at: Mapped[datetime | None] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=True,
|
||||
comment="最后登录时间",
|
||||
)
|
||||
|
||||
# 关系
|
||||
balance_account: Mapped["UserBalance | None"] = relationship(
|
||||
"UserBalance",
|
||||
back_populates="user",
|
||||
uselist=False,
|
||||
lazy="selectin",
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<User(id={self.id!r}, username={self.username!r})>"
|
||||
|
||||
19
app/repositories/__init__.py
Normal file
19
app/repositories/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""数据仓库层"""
|
||||
|
||||
from app.repositories.user import UserRepository
|
||||
from app.repositories.balance import BalanceRepository, TransactionRepository
|
||||
from app.repositories.redeem_code import (
|
||||
RedeemCodeRepository,
|
||||
RedeemCodeBatchRepository,
|
||||
RedeemCodeUsageLogRepository,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"UserRepository",
|
||||
"BalanceRepository",
|
||||
"TransactionRepository",
|
||||
"RedeemCodeRepository",
|
||||
"RedeemCodeBatchRepository",
|
||||
"RedeemCodeUsageLogRepository",
|
||||
]
|
||||
|
||||
378
app/repositories/balance.py
Normal file
378
app/repositories/balance.py
Normal file
@@ -0,0 +1,378 @@
|
||||
"""
|
||||
余额仓库
|
||||
|
||||
处理余额相关的数据库操作。
|
||||
|
||||
设计说明:
|
||||
- 使用乐观锁(version)防止并发更新冲突
|
||||
- 提供行级锁支持(悲观锁)用于关键操作
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select, update, and_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.balance import (
|
||||
UserBalance,
|
||||
BalanceTransaction,
|
||||
TransactionType,
|
||||
TransactionStatus,
|
||||
)
|
||||
from app.repositories.base import BaseRepository
|
||||
|
||||
|
||||
class BalanceRepository(BaseRepository[UserBalance]):
|
||||
"""余额仓库"""
|
||||
|
||||
model = UserBalance
|
||||
|
||||
async def get_by_user_id(self, user_id: str) -> UserBalance | None:
|
||||
"""
|
||||
通过用户 ID 获取余额账户
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
|
||||
Returns:
|
||||
余额账户或 None
|
||||
"""
|
||||
stmt = select(UserBalance).where(UserBalance.user_id == user_id)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_user_id_for_update(self, user_id: str) -> UserBalance | None:
|
||||
"""
|
||||
通过用户 ID 获取余额账户(加行级锁)
|
||||
|
||||
用于需要原子性更新的场景,如扣款操作。
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
|
||||
Returns:
|
||||
余额账户或 None
|
||||
"""
|
||||
stmt = (
|
||||
select(UserBalance)
|
||||
.where(UserBalance.user_id == user_id)
|
||||
.with_for_update() # 行级锁
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_or_create(self, user_id: str) -> UserBalance:
|
||||
"""
|
||||
获取或创建余额账户
|
||||
|
||||
如果用户没有余额账户,自动创建一个。
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
|
||||
Returns:
|
||||
余额账户
|
||||
"""
|
||||
balance = await self.get_by_user_id(user_id)
|
||||
if balance is None:
|
||||
balance = await self.create(user_id=user_id)
|
||||
return balance
|
||||
|
||||
async def get_or_create_for_update(self, user_id: str) -> UserBalance:
|
||||
"""
|
||||
获取或创建余额账户(加行级锁)
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
|
||||
Returns:
|
||||
余额账户
|
||||
"""
|
||||
balance = await self.get_by_user_id_for_update(user_id)
|
||||
if balance is None:
|
||||
balance = await self.create(user_id=user_id)
|
||||
# 重新获取并加锁
|
||||
await self.session.flush()
|
||||
balance = await self.get_by_user_id_for_update(user_id)
|
||||
return balance # type: ignore
|
||||
|
||||
async def update_balance_optimistic(
|
||||
self,
|
||||
balance: UserBalance,
|
||||
delta: int,
|
||||
*,
|
||||
is_recharge: bool = False,
|
||||
is_consumption: bool = False,
|
||||
) -> bool:
|
||||
"""
|
||||
使用乐观锁更新余额
|
||||
|
||||
通过版本号检查确保并发安全。
|
||||
|
||||
Args:
|
||||
balance: 余额账户
|
||||
delta: 变化量(正数增加,负数减少)
|
||||
is_recharge: 是否为充值
|
||||
is_consumption: 是否为消费
|
||||
|
||||
Returns:
|
||||
是否更新成功
|
||||
"""
|
||||
current_version = balance.version
|
||||
new_balance = balance.balance + delta
|
||||
|
||||
# 构建更新语句
|
||||
update_values: dict[str, Any] = {
|
||||
"balance": new_balance,
|
||||
"version": current_version + 1,
|
||||
}
|
||||
|
||||
if is_recharge and delta > 0:
|
||||
update_values["total_recharged"] = balance.total_recharged + delta
|
||||
if is_consumption and delta < 0:
|
||||
update_values["total_consumed"] = balance.total_consumed + abs(delta)
|
||||
|
||||
stmt = (
|
||||
update(UserBalance)
|
||||
.where(
|
||||
and_(
|
||||
UserBalance.id == balance.id,
|
||||
UserBalance.version == current_version, # 乐观锁检查
|
||||
)
|
||||
)
|
||||
.values(**update_values)
|
||||
)
|
||||
|
||||
result = await self.session.execute(stmt)
|
||||
|
||||
if result.rowcount == 1:
|
||||
# 更新成功,刷新对象
|
||||
balance.balance = new_balance
|
||||
balance.version = current_version + 1
|
||||
if is_recharge and delta > 0:
|
||||
balance.total_recharged += delta
|
||||
if is_consumption and delta < 0:
|
||||
balance.total_consumed += abs(delta)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def freeze_balance(
|
||||
self,
|
||||
balance: UserBalance,
|
||||
amount: int,
|
||||
) -> bool:
|
||||
"""
|
||||
冻结余额
|
||||
|
||||
Args:
|
||||
balance: 余额账户
|
||||
amount: 冻结金额(正数)
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
if amount <= 0:
|
||||
return False
|
||||
if balance.available_balance < amount:
|
||||
return False
|
||||
|
||||
current_version = balance.version
|
||||
stmt = (
|
||||
update(UserBalance)
|
||||
.where(
|
||||
and_(
|
||||
UserBalance.id == balance.id,
|
||||
UserBalance.version == current_version,
|
||||
UserBalance.balance - UserBalance.frozen_balance >= amount,
|
||||
)
|
||||
)
|
||||
.values(
|
||||
frozen_balance=UserBalance.frozen_balance + amount,
|
||||
version=current_version + 1,
|
||||
)
|
||||
)
|
||||
|
||||
result = await self.session.execute(stmt)
|
||||
|
||||
if result.rowcount == 1:
|
||||
balance.frozen_balance += amount
|
||||
balance.version = current_version + 1
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def unfreeze_balance(
|
||||
self,
|
||||
balance: UserBalance,
|
||||
amount: int,
|
||||
) -> bool:
|
||||
"""
|
||||
解冻余额
|
||||
|
||||
Args:
|
||||
balance: 余额账户
|
||||
amount: 解冻金额(正数)
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
if amount <= 0:
|
||||
return False
|
||||
if balance.frozen_balance < amount:
|
||||
return False
|
||||
|
||||
current_version = balance.version
|
||||
stmt = (
|
||||
update(UserBalance)
|
||||
.where(
|
||||
and_(
|
||||
UserBalance.id == balance.id,
|
||||
UserBalance.version == current_version,
|
||||
UserBalance.frozen_balance >= amount,
|
||||
)
|
||||
)
|
||||
.values(
|
||||
frozen_balance=UserBalance.frozen_balance - amount,
|
||||
version=current_version + 1,
|
||||
)
|
||||
)
|
||||
|
||||
result = await self.session.execute(stmt)
|
||||
|
||||
if result.rowcount == 1:
|
||||
balance.frozen_balance -= amount
|
||||
balance.version = current_version + 1
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class TransactionRepository(BaseRepository[BalanceTransaction]):
|
||||
"""交易记录仓库"""
|
||||
|
||||
model = BalanceTransaction
|
||||
|
||||
async def get_by_user_id(
|
||||
self,
|
||||
user_id: str,
|
||||
*,
|
||||
offset: int = 0,
|
||||
limit: int = 20,
|
||||
transaction_type: TransactionType | None = None,
|
||||
status: TransactionStatus | None = None,
|
||||
) -> list[BalanceTransaction]:
|
||||
"""
|
||||
获取用户的交易记录
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
offset: 偏移量
|
||||
limit: 限制数量
|
||||
transaction_type: 交易类型过滤
|
||||
status: 状态过滤
|
||||
|
||||
Returns:
|
||||
交易记录列表
|
||||
"""
|
||||
stmt = select(BalanceTransaction).where(
|
||||
BalanceTransaction.user_id == user_id
|
||||
)
|
||||
|
||||
if transaction_type:
|
||||
stmt = stmt.where(BalanceTransaction.transaction_type == transaction_type)
|
||||
if status:
|
||||
stmt = stmt.where(BalanceTransaction.status == status)
|
||||
|
||||
stmt = (
|
||||
stmt
|
||||
.order_by(BalanceTransaction.created_at.desc())
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
result = await self.session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def count_by_user_id(
|
||||
self,
|
||||
user_id: str,
|
||||
*,
|
||||
transaction_type: TransactionType | None = None,
|
||||
status: TransactionStatus | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
统计用户的交易记录数量
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
transaction_type: 交易类型过滤
|
||||
status: 状态过滤
|
||||
|
||||
Returns:
|
||||
记录数量
|
||||
"""
|
||||
from sqlalchemy import func
|
||||
|
||||
stmt = select(func.count()).select_from(BalanceTransaction).where(
|
||||
BalanceTransaction.user_id == user_id
|
||||
)
|
||||
|
||||
if transaction_type:
|
||||
stmt = stmt.where(BalanceTransaction.transaction_type == transaction_type)
|
||||
if status:
|
||||
stmt = stmt.where(BalanceTransaction.status == status)
|
||||
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar() or 0
|
||||
|
||||
async def get_by_idempotency_key(
|
||||
self,
|
||||
idempotency_key: str,
|
||||
) -> BalanceTransaction | None:
|
||||
"""
|
||||
通过幂等键获取交易记录
|
||||
|
||||
用于防止重复提交。
|
||||
|
||||
Args:
|
||||
idempotency_key: 幂等键
|
||||
|
||||
Returns:
|
||||
交易记录或 None
|
||||
"""
|
||||
stmt = select(BalanceTransaction).where(
|
||||
BalanceTransaction.idempotency_key == idempotency_key
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_reference(
|
||||
self,
|
||||
reference_type: str,
|
||||
reference_id: str,
|
||||
) -> list[BalanceTransaction]:
|
||||
"""
|
||||
通过业务关联获取交易记录
|
||||
|
||||
Args:
|
||||
reference_type: 关联业务类型
|
||||
reference_id: 关联业务 ID
|
||||
|
||||
Returns:
|
||||
交易记录列表
|
||||
"""
|
||||
stmt = (
|
||||
select(BalanceTransaction)
|
||||
.where(
|
||||
and_(
|
||||
BalanceTransaction.reference_type == reference_type,
|
||||
BalanceTransaction.reference_id == reference_id,
|
||||
)
|
||||
)
|
||||
.order_by(BalanceTransaction.created_at.desc())
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
138
app/repositories/base.py
Normal file
138
app/repositories/base.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""
|
||||
基础仓库类
|
||||
|
||||
提供通用的 CRUD 操作封装。
|
||||
"""
|
||||
|
||||
from typing import Any, Generic, TypeVar
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import Base
|
||||
|
||||
ModelT = TypeVar("ModelT", bound=Base)
|
||||
|
||||
|
||||
class BaseRepository(Generic[ModelT]):
|
||||
"""
|
||||
基础仓库类
|
||||
|
||||
提供通用的数据库操作方法。
|
||||
"""
|
||||
|
||||
model: type[ModelT]
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
"""
|
||||
初始化仓库
|
||||
|
||||
Args:
|
||||
session: 异步数据库会话
|
||||
"""
|
||||
self.session = session
|
||||
|
||||
async def get_by_id(self, id: str) -> ModelT | None:
|
||||
"""
|
||||
通过 ID 获取实体
|
||||
|
||||
Args:
|
||||
id: 实体 ID
|
||||
|
||||
Returns:
|
||||
实体对象或 None
|
||||
"""
|
||||
return await self.session.get(self.model, id)
|
||||
|
||||
async def get_all(
|
||||
self,
|
||||
*,
|
||||
offset: int = 0,
|
||||
limit: int = 100,
|
||||
) -> list[ModelT]:
|
||||
"""
|
||||
获取所有实体
|
||||
|
||||
Args:
|
||||
offset: 偏移量
|
||||
limit: 限制数量
|
||||
|
||||
Returns:
|
||||
实体列表
|
||||
"""
|
||||
stmt = select(self.model).offset(offset).limit(limit)
|
||||
result = await self.session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def count(self) -> int:
|
||||
"""
|
||||
获取实体总数
|
||||
|
||||
Returns:
|
||||
实体数量
|
||||
"""
|
||||
stmt = select(func.count()).select_from(self.model)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar() or 0
|
||||
|
||||
async def create(self, **kwargs: Any) -> ModelT:
|
||||
"""
|
||||
创建新实体
|
||||
|
||||
Args:
|
||||
**kwargs: 实体属性
|
||||
|
||||
Returns:
|
||||
新创建的实体
|
||||
"""
|
||||
if "id" not in kwargs:
|
||||
kwargs["id"] = str(uuid4())
|
||||
|
||||
entity = self.model(**kwargs)
|
||||
self.session.add(entity)
|
||||
await self.session.flush()
|
||||
await self.session.refresh(entity)
|
||||
return entity
|
||||
|
||||
async def update(
|
||||
self,
|
||||
entity: ModelT,
|
||||
**kwargs: Any,
|
||||
) -> ModelT:
|
||||
"""
|
||||
更新实体
|
||||
|
||||
Args:
|
||||
entity: 要更新的实体
|
||||
**kwargs: 要更新的属性
|
||||
|
||||
Returns:
|
||||
更新后的实体
|
||||
"""
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(entity, key):
|
||||
setattr(entity, key, value)
|
||||
|
||||
await self.session.flush()
|
||||
await self.session.refresh(entity)
|
||||
return entity
|
||||
|
||||
async def delete(self, entity: ModelT) -> None:
|
||||
"""
|
||||
删除实体
|
||||
|
||||
Args:
|
||||
entity: 要删除的实体
|
||||
"""
|
||||
await self.session.delete(entity)
|
||||
await self.session.flush()
|
||||
|
||||
async def commit(self) -> None:
|
||||
"""提交事务"""
|
||||
await self.session.commit()
|
||||
|
||||
async def rollback(self) -> None:
|
||||
"""回滚事务"""
|
||||
await self.session.rollback()
|
||||
|
||||
462
app/repositories/redeem_code.py
Normal file
462
app/repositories/redeem_code.py
Normal file
@@ -0,0 +1,462 @@
|
||||
"""
|
||||
兑换码仓库
|
||||
|
||||
处理兑换码相关的数据库操作。
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select, update, and_, or_, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.redeem_code import (
|
||||
RedeemCode,
|
||||
RedeemCodeBatch,
|
||||
RedeemCodeUsageLog,
|
||||
RedeemCodeStatus,
|
||||
)
|
||||
from app.repositories.base import BaseRepository
|
||||
|
||||
|
||||
class RedeemCodeRepository(BaseRepository[RedeemCode]):
|
||||
"""兑换码仓库"""
|
||||
|
||||
model = RedeemCode
|
||||
|
||||
async def get_by_code(self, code: str) -> RedeemCode | None:
|
||||
"""
|
||||
通过兑换码获取记录
|
||||
|
||||
Args:
|
||||
code: 兑换码
|
||||
|
||||
Returns:
|
||||
兑换码记录或 None
|
||||
"""
|
||||
# 标准化兑换码格式
|
||||
normalized_code = code.strip().upper().replace(" ", "")
|
||||
stmt = select(RedeemCode).where(RedeemCode.code == normalized_code)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_code_for_update(self, code: str) -> RedeemCode | None:
|
||||
"""
|
||||
通过兑换码获取记录(加行级锁)
|
||||
|
||||
用于兑换操作,防止并发兑换。
|
||||
|
||||
Args:
|
||||
code: 兑换码
|
||||
|
||||
Returns:
|
||||
兑换码记录或 None
|
||||
"""
|
||||
normalized_code = code.strip().upper().replace(" ", "")
|
||||
stmt = (
|
||||
select(RedeemCode)
|
||||
.where(RedeemCode.code == normalized_code)
|
||||
.with_for_update()
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_valid_code(self, code: str) -> RedeemCode | None:
|
||||
"""
|
||||
获取有效的兑换码
|
||||
|
||||
检查状态、使用次数和有效期。
|
||||
|
||||
Args:
|
||||
code: 兑换码
|
||||
|
||||
Returns:
|
||||
有效的兑换码或 None
|
||||
"""
|
||||
normalized_code = code.strip().upper().replace(" ", "")
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
stmt = (
|
||||
select(RedeemCode)
|
||||
.where(
|
||||
and_(
|
||||
RedeemCode.code == normalized_code,
|
||||
RedeemCode.status == RedeemCodeStatus.ACTIVE,
|
||||
RedeemCode.used_count < RedeemCode.max_uses,
|
||||
or_(
|
||||
RedeemCode.expires_at.is_(None),
|
||||
RedeemCode.expires_at > now,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def mark_as_used(
|
||||
self,
|
||||
code: RedeemCode,
|
||||
user_id: str,
|
||||
) -> bool:
|
||||
"""
|
||||
标记兑换码已使用
|
||||
|
||||
更新使用计数和状态。
|
||||
|
||||
Args:
|
||||
code: 兑换码记录
|
||||
user_id: 使用者 ID
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
new_used_count = code.used_count + 1
|
||||
new_status = (
|
||||
RedeemCodeStatus.USED
|
||||
if new_used_count >= code.max_uses
|
||||
else RedeemCodeStatus.ACTIVE
|
||||
)
|
||||
|
||||
stmt = (
|
||||
update(RedeemCode)
|
||||
.where(
|
||||
and_(
|
||||
RedeemCode.id == code.id,
|
||||
RedeemCode.used_count == code.used_count, # 乐观锁
|
||||
)
|
||||
)
|
||||
.values(
|
||||
used_count=new_used_count,
|
||||
status=new_status,
|
||||
used_by=user_id,
|
||||
used_at=now,
|
||||
)
|
||||
)
|
||||
|
||||
result = await self.session.execute(stmt)
|
||||
|
||||
if result.rowcount == 1:
|
||||
code.used_count = new_used_count
|
||||
code.status = new_status
|
||||
code.used_by = user_id
|
||||
code.used_at = now
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def get_all_with_filters(
|
||||
self,
|
||||
*,
|
||||
offset: int = 0,
|
||||
limit: int = 20,
|
||||
status: RedeemCodeStatus | None = None,
|
||||
batch_id: str | None = None,
|
||||
code_like: str | None = None,
|
||||
created_after: datetime | None = None,
|
||||
created_before: datetime | None = None,
|
||||
) -> list[RedeemCode]:
|
||||
"""
|
||||
获取兑换码列表(支持过滤)
|
||||
|
||||
Args:
|
||||
offset: 偏移量
|
||||
limit: 限制数量
|
||||
status: 状态过滤
|
||||
batch_id: 批次 ID 过滤
|
||||
code_like: 兑换码模糊匹配
|
||||
created_after: 创建时间起始
|
||||
created_before: 创建时间结束
|
||||
|
||||
Returns:
|
||||
兑换码列表
|
||||
"""
|
||||
stmt = select(RedeemCode)
|
||||
|
||||
conditions = []
|
||||
if status:
|
||||
conditions.append(RedeemCode.status == status)
|
||||
if batch_id:
|
||||
conditions.append(RedeemCode.batch_id == batch_id)
|
||||
if code_like:
|
||||
conditions.append(RedeemCode.code.contains(code_like.upper()))
|
||||
if created_after:
|
||||
conditions.append(RedeemCode.created_at >= created_after)
|
||||
if created_before:
|
||||
conditions.append(RedeemCode.created_at <= created_before)
|
||||
|
||||
if conditions:
|
||||
stmt = stmt.where(and_(*conditions))
|
||||
|
||||
stmt = (
|
||||
stmt
|
||||
.order_by(RedeemCode.created_at.desc())
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
result = await self.session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def count_with_filters(
|
||||
self,
|
||||
*,
|
||||
status: RedeemCodeStatus | None = None,
|
||||
batch_id: str | None = None,
|
||||
code_like: str | None = None,
|
||||
created_after: datetime | None = None,
|
||||
created_before: datetime | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
统计兑换码数量(支持过滤)
|
||||
"""
|
||||
stmt = select(func.count()).select_from(RedeemCode)
|
||||
|
||||
conditions = []
|
||||
if status:
|
||||
conditions.append(RedeemCode.status == status)
|
||||
if batch_id:
|
||||
conditions.append(RedeemCode.batch_id == batch_id)
|
||||
if code_like:
|
||||
conditions.append(RedeemCode.code.contains(code_like.upper()))
|
||||
if created_after:
|
||||
conditions.append(RedeemCode.created_at >= created_after)
|
||||
if created_before:
|
||||
conditions.append(RedeemCode.created_at <= created_before)
|
||||
|
||||
if conditions:
|
||||
stmt = stmt.where(and_(*conditions))
|
||||
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar() or 0
|
||||
|
||||
async def bulk_create(
|
||||
self,
|
||||
codes_data: list[dict[str, Any]],
|
||||
) -> list[RedeemCode]:
|
||||
"""
|
||||
批量创建兑换码
|
||||
|
||||
Args:
|
||||
codes_data: 兑换码数据列表
|
||||
|
||||
Returns:
|
||||
创建的兑换码列表
|
||||
"""
|
||||
codes = [RedeemCode(**data) for data in codes_data]
|
||||
self.session.add_all(codes)
|
||||
await self.session.flush()
|
||||
return codes
|
||||
|
||||
async def disable_code(self, code: RedeemCode) -> RedeemCode:
|
||||
"""
|
||||
禁用兑换码
|
||||
|
||||
Args:
|
||||
code: 兑换码记录
|
||||
|
||||
Returns:
|
||||
更新后的兑换码
|
||||
"""
|
||||
return await self.update(code, status=RedeemCodeStatus.DISABLED)
|
||||
|
||||
async def enable_code(self, code: RedeemCode) -> RedeemCode:
|
||||
"""
|
||||
启用兑换码
|
||||
|
||||
Args:
|
||||
code: 兑换码记录
|
||||
|
||||
Returns:
|
||||
更新后的兑换码
|
||||
"""
|
||||
# 只有禁用状态的可以重新启用
|
||||
if code.status != RedeemCodeStatus.DISABLED:
|
||||
return code
|
||||
|
||||
# 如果使用次数已满,改为已使用状态
|
||||
if code.used_count >= code.max_uses:
|
||||
return await self.update(code, status=RedeemCodeStatus.USED)
|
||||
|
||||
return await self.update(code, status=RedeemCodeStatus.ACTIVE)
|
||||
|
||||
|
||||
class RedeemCodeBatchRepository(BaseRepository[RedeemCodeBatch]):
|
||||
"""兑换码批次仓库"""
|
||||
|
||||
model = RedeemCodeBatch
|
||||
|
||||
async def get_all_batches(
|
||||
self,
|
||||
*,
|
||||
offset: int = 0,
|
||||
limit: int = 20,
|
||||
) -> list[RedeemCodeBatch]:
|
||||
"""
|
||||
获取所有批次
|
||||
|
||||
Args:
|
||||
offset: 偏移量
|
||||
limit: 限制数量
|
||||
|
||||
Returns:
|
||||
批次列表
|
||||
"""
|
||||
stmt = (
|
||||
select(RedeemCodeBatch)
|
||||
.order_by(RedeemCodeBatch.created_at.desc())
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def increment_used_count(self, batch_id: str) -> None:
|
||||
"""
|
||||
增加批次已使用计数
|
||||
|
||||
Args:
|
||||
batch_id: 批次 ID
|
||||
"""
|
||||
stmt = (
|
||||
update(RedeemCodeBatch)
|
||||
.where(RedeemCodeBatch.id == batch_id)
|
||||
.values(used_count=RedeemCodeBatch.used_count + 1)
|
||||
)
|
||||
await self.session.execute(stmt)
|
||||
|
||||
|
||||
class RedeemCodeUsageLogRepository(BaseRepository[RedeemCodeUsageLog]):
|
||||
"""兑换码使用日志仓库"""
|
||||
|
||||
model = RedeemCodeUsageLog
|
||||
|
||||
async def get_by_code_id(
|
||||
self,
|
||||
redeem_code_id: str,
|
||||
*,
|
||||
offset: int = 0,
|
||||
limit: int = 20,
|
||||
) -> list[RedeemCodeUsageLog]:
|
||||
"""
|
||||
获取兑换码的使用日志
|
||||
|
||||
Args:
|
||||
redeem_code_id: 兑换码 ID
|
||||
offset: 偏移量
|
||||
limit: 限制数量
|
||||
|
||||
Returns:
|
||||
使用日志列表
|
||||
"""
|
||||
stmt = (
|
||||
select(RedeemCodeUsageLog)
|
||||
.where(RedeemCodeUsageLog.redeem_code_id == redeem_code_id)
|
||||
.order_by(RedeemCodeUsageLog.created_at.desc())
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_by_user_id(
|
||||
self,
|
||||
user_id: str,
|
||||
*,
|
||||
offset: int = 0,
|
||||
limit: int = 20,
|
||||
) -> list[RedeemCodeUsageLog]:
|
||||
"""
|
||||
获取用户的兑换日志
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
offset: 偏移量
|
||||
limit: 限制数量
|
||||
|
||||
Returns:
|
||||
使用日志列表
|
||||
"""
|
||||
stmt = (
|
||||
select(RedeemCodeUsageLog)
|
||||
.where(RedeemCodeUsageLog.user_id == user_id)
|
||||
.order_by(RedeemCodeUsageLog.created_at.desc())
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_all_with_filters(
|
||||
self,
|
||||
*,
|
||||
offset: int = 0,
|
||||
limit: int = 20,
|
||||
redeem_code_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
code_like: str | None = None,
|
||||
created_after: datetime | None = None,
|
||||
created_before: datetime | None = None,
|
||||
) -> list[RedeemCodeUsageLog]:
|
||||
"""
|
||||
获取使用日志(支持过滤)
|
||||
"""
|
||||
stmt = select(RedeemCodeUsageLog)
|
||||
|
||||
conditions = []
|
||||
if redeem_code_id:
|
||||
conditions.append(RedeemCodeUsageLog.redeem_code_id == redeem_code_id)
|
||||
if user_id:
|
||||
conditions.append(RedeemCodeUsageLog.user_id == user_id)
|
||||
if code_like:
|
||||
conditions.append(RedeemCodeUsageLog.code_snapshot.contains(code_like.upper()))
|
||||
if created_after:
|
||||
conditions.append(RedeemCodeUsageLog.created_at >= created_after)
|
||||
if created_before:
|
||||
conditions.append(RedeemCodeUsageLog.created_at <= created_before)
|
||||
|
||||
if conditions:
|
||||
stmt = stmt.where(and_(*conditions))
|
||||
|
||||
stmt = (
|
||||
stmt
|
||||
.order_by(RedeemCodeUsageLog.created_at.desc())
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
result = await self.session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def count_with_filters(
|
||||
self,
|
||||
*,
|
||||
redeem_code_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
code_like: str | None = None,
|
||||
created_after: datetime | None = None,
|
||||
created_before: datetime | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
统计使用日志数量
|
||||
"""
|
||||
stmt = select(func.count()).select_from(RedeemCodeUsageLog)
|
||||
|
||||
conditions = []
|
||||
if redeem_code_id:
|
||||
conditions.append(RedeemCodeUsageLog.redeem_code_id == redeem_code_id)
|
||||
if user_id:
|
||||
conditions.append(RedeemCodeUsageLog.user_id == user_id)
|
||||
if code_like:
|
||||
conditions.append(RedeemCodeUsageLog.code_snapshot.contains(code_like.upper()))
|
||||
if created_after:
|
||||
conditions.append(RedeemCodeUsageLog.created_at >= created_after)
|
||||
if created_before:
|
||||
conditions.append(RedeemCodeUsageLog.created_at <= created_before)
|
||||
|
||||
if conditions:
|
||||
stmt = stmt.where(and_(*conditions))
|
||||
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar() or 0
|
||||
|
||||
141
app/repositories/user.py
Normal file
141
app/repositories/user.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
用户仓库
|
||||
|
||||
处理用户相关的数据库操作。
|
||||
"""
|
||||
|
||||
from sqlalchemy import or_, select
|
||||
|
||||
from app.models.user import User
|
||||
from app.repositories.base import BaseRepository
|
||||
|
||||
|
||||
class UserRepository(BaseRepository[User]):
|
||||
"""用户数据仓库"""
|
||||
|
||||
model = User
|
||||
|
||||
async def get_by_username(self, username: str) -> User | None:
|
||||
"""
|
||||
通过用户名获取用户
|
||||
|
||||
Args:
|
||||
username: 用户名
|
||||
|
||||
Returns:
|
||||
用户对象或 None
|
||||
"""
|
||||
stmt = select(User).where(User.username == username.lower())
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_email(self, email: str) -> User | None:
|
||||
"""
|
||||
通过邮箱获取用户
|
||||
|
||||
Args:
|
||||
email: 邮箱地址
|
||||
|
||||
Returns:
|
||||
用户对象或 None
|
||||
"""
|
||||
stmt = select(User).where(User.email == email.lower())
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_username_or_email(self, identifier: str) -> User | None:
|
||||
"""
|
||||
通过用户名或邮箱获取用户
|
||||
|
||||
Args:
|
||||
identifier: 用户名或邮箱
|
||||
|
||||
Returns:
|
||||
用户对象或 None
|
||||
"""
|
||||
identifier_lower = identifier.lower()
|
||||
stmt = select(User).where(
|
||||
or_(
|
||||
User.username == identifier_lower,
|
||||
User.email == identifier_lower,
|
||||
)
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def exists_by_username(self, username: str) -> bool:
|
||||
"""
|
||||
检查用户名是否存在
|
||||
|
||||
Args:
|
||||
username: 用户名
|
||||
|
||||
Returns:
|
||||
是否存在
|
||||
"""
|
||||
user = await self.get_by_username(username)
|
||||
return user is not None
|
||||
|
||||
async def exists_by_email(self, email: str) -> bool:
|
||||
"""
|
||||
检查邮箱是否存在
|
||||
|
||||
Args:
|
||||
email: 邮箱地址
|
||||
|
||||
Returns:
|
||||
是否存在
|
||||
"""
|
||||
if not email:
|
||||
return False
|
||||
user = await self.get_by_email(email)
|
||||
return user is not None
|
||||
|
||||
async def get_by_oauth(
|
||||
self,
|
||||
provider: str,
|
||||
oauth_user_id: str,
|
||||
) -> User | None:
|
||||
"""
|
||||
通过 OAuth2 提供商和用户 ID 获取用户
|
||||
|
||||
Args:
|
||||
provider: OAuth2 提供商标识
|
||||
oauth_user_id: OAuth2 用户 ID
|
||||
|
||||
Returns:
|
||||
用户对象或 None
|
||||
"""
|
||||
stmt = select(User).where(
|
||||
User.oauth_provider == provider,
|
||||
User.oauth_user_id == oauth_user_id,
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_active_users(
|
||||
self,
|
||||
*,
|
||||
offset: int = 0,
|
||||
limit: int = 100,
|
||||
) -> list[User]:
|
||||
"""
|
||||
获取活跃用户列表
|
||||
|
||||
Args:
|
||||
offset: 偏移量
|
||||
limit: 限制数量
|
||||
|
||||
Returns:
|
||||
活跃用户列表
|
||||
"""
|
||||
stmt = (
|
||||
select(User)
|
||||
.where(User.is_active == True) # noqa: E712
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.order_by(User.created_at.desc())
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
75
app/schemas/__init__.py
Normal file
75
app/schemas/__init__.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""Pydantic 数据模式"""
|
||||
|
||||
from app.schemas.auth import (
|
||||
LoginRequest,
|
||||
PasswordChangeRequest,
|
||||
RefreshTokenRequest,
|
||||
TokenResponse,
|
||||
)
|
||||
from app.schemas.oauth2 import (
|
||||
OAuth2AuthorizeResponse,
|
||||
OAuth2CallbackRequest,
|
||||
OAuth2LoginResponse,
|
||||
OAuth2TokenData,
|
||||
OAuth2UserInfo,
|
||||
)
|
||||
from app.schemas.user import (
|
||||
UserCreate,
|
||||
UserResponse,
|
||||
UserUpdate,
|
||||
)
|
||||
from app.schemas.balance import (
|
||||
BalanceResponse,
|
||||
TransactionResponse,
|
||||
DeductionRequest,
|
||||
DeductionResponse,
|
||||
AdminAdjustmentRequest,
|
||||
AdminBalanceResponse,
|
||||
)
|
||||
from app.schemas.redeem_code import (
|
||||
RedeemRequest,
|
||||
RedeemResponse,
|
||||
RedeemCodeResponse,
|
||||
BatchCreateRequest,
|
||||
BatchResponse,
|
||||
BulkImportRequest,
|
||||
BulkImportResponse,
|
||||
ExportResponse,
|
||||
UsageLogResponse,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Auth
|
||||
"LoginRequest",
|
||||
"TokenResponse",
|
||||
"RefreshTokenRequest",
|
||||
"PasswordChangeRequest",
|
||||
# OAuth2
|
||||
"OAuth2AuthorizeResponse",
|
||||
"OAuth2CallbackRequest",
|
||||
"OAuth2LoginResponse",
|
||||
"OAuth2TokenData",
|
||||
"OAuth2UserInfo",
|
||||
# User
|
||||
"UserCreate",
|
||||
"UserResponse",
|
||||
"UserUpdate",
|
||||
# Balance
|
||||
"BalanceResponse",
|
||||
"TransactionResponse",
|
||||
"DeductionRequest",
|
||||
"DeductionResponse",
|
||||
"AdminAdjustmentRequest",
|
||||
"AdminBalanceResponse",
|
||||
# Redeem Code
|
||||
"RedeemRequest",
|
||||
"RedeemResponse",
|
||||
"RedeemCodeResponse",
|
||||
"BatchCreateRequest",
|
||||
"BatchResponse",
|
||||
"BulkImportRequest",
|
||||
"BulkImportResponse",
|
||||
"ExportResponse",
|
||||
"UsageLogResponse",
|
||||
]
|
||||
|
||||
112
app/schemas/auth.py
Normal file
112
app/schemas/auth.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""
|
||||
认证相关 Schema
|
||||
|
||||
定义登录、令牌等数据结构。
|
||||
"""
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from app.schemas.base import BaseSchema
|
||||
|
||||
|
||||
class LoginRequest(BaseSchema):
|
||||
"""登录请求"""
|
||||
|
||||
username: Annotated[
|
||||
str,
|
||||
Field(
|
||||
min_length=1,
|
||||
description="用户名或邮箱",
|
||||
examples=["john_doe"],
|
||||
),
|
||||
]
|
||||
password: Annotated[
|
||||
str,
|
||||
Field(
|
||||
min_length=1,
|
||||
description="密码",
|
||||
examples=["SecurePass123"],
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class TokenResponse(BaseSchema):
|
||||
"""令牌响应"""
|
||||
|
||||
access_token: str = Field(description="访问令牌")
|
||||
refresh_token: str = Field(description="刷新令牌")
|
||||
token_type: str = Field(default="Bearer", description="令牌类型")
|
||||
expires_in: int = Field(description="访问令牌过期时间(秒)")
|
||||
|
||||
|
||||
class RefreshTokenRequest(BaseSchema):
|
||||
"""刷新令牌请求"""
|
||||
|
||||
refresh_token: Annotated[
|
||||
str,
|
||||
Field(
|
||||
min_length=1,
|
||||
description="刷新令牌",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class PasswordChangeRequest(BaseSchema):
|
||||
"""修改密码请求"""
|
||||
|
||||
current_password: Annotated[
|
||||
str,
|
||||
Field(
|
||||
min_length=1,
|
||||
description="当前密码",
|
||||
),
|
||||
]
|
||||
new_password: Annotated[
|
||||
str,
|
||||
Field(
|
||||
min_length=8,
|
||||
max_length=128,
|
||||
description="新密码",
|
||||
),
|
||||
]
|
||||
|
||||
def model_post_init(self, __context) -> None:
|
||||
"""验证新旧密码不同"""
|
||||
if self.current_password == self.new_password:
|
||||
raise ValueError("新密码不能与当前密码相同")
|
||||
|
||||
|
||||
class PasswordResetRequest(BaseSchema):
|
||||
"""密码重置请求(忘记密码)"""
|
||||
|
||||
email: Annotated[
|
||||
str,
|
||||
Field(
|
||||
min_length=1,
|
||||
description="注册邮箱",
|
||||
examples=["user@example.com"],
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class PasswordResetConfirm(BaseSchema):
|
||||
"""密码重置确认"""
|
||||
|
||||
token: Annotated[
|
||||
str,
|
||||
Field(
|
||||
min_length=1,
|
||||
description="重置令牌",
|
||||
),
|
||||
]
|
||||
new_password: Annotated[
|
||||
str,
|
||||
Field(
|
||||
min_length=8,
|
||||
max_length=128,
|
||||
description="新密码",
|
||||
),
|
||||
]
|
||||
|
||||
285
app/schemas/balance.py
Normal file
285
app/schemas/balance.py
Normal file
@@ -0,0 +1,285 @@
|
||||
"""
|
||||
余额相关 Schema
|
||||
|
||||
定义余额数据的验证和序列化规则。
|
||||
|
||||
设计说明:
|
||||
- 内部存储使用整数单位(units),外部显示使用小数(display)
|
||||
- 1.00 显示余额 = 1000 单位额度
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import Field, field_validator, computed_field
|
||||
|
||||
from app.schemas.base import BaseSchema, PaginatedResponse
|
||||
from app.models.balance import TransactionType, TransactionStatus
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 余额单位转换工具
|
||||
# ============================================================
|
||||
|
||||
UNITS_PER_DISPLAY = 1000 # 1.00 显示余额 = 1000 单位额度
|
||||
|
||||
|
||||
def display_to_units(display_amount: float) -> int:
|
||||
"""将显示金额转换为单位额度"""
|
||||
return int(round(display_amount * UNITS_PER_DISPLAY))
|
||||
|
||||
|
||||
def units_to_display(units: int) -> float:
|
||||
"""将单位额度转换为显示金额"""
|
||||
return units / UNITS_PER_DISPLAY
|
||||
|
||||
|
||||
def format_display(units: int) -> str:
|
||||
"""格式化显示金额(2 位小数)"""
|
||||
return f"{units / UNITS_PER_DISPLAY:.2f}"
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 余额账户 Schema
|
||||
# ============================================================
|
||||
|
||||
class BalanceResponse(BaseSchema):
|
||||
"""余额信息响应"""
|
||||
|
||||
user_id: str
|
||||
balance_units: Annotated[
|
||||
int,
|
||||
Field(description="当前余额(单位额度)"),
|
||||
]
|
||||
frozen_units: Annotated[
|
||||
int,
|
||||
Field(description="冻结余额(单位额度)"),
|
||||
]
|
||||
total_recharged_units: Annotated[
|
||||
int,
|
||||
Field(description="累计充值(单位额度)"),
|
||||
]
|
||||
total_consumed_units: Annotated[
|
||||
int,
|
||||
Field(description="累计消费(单位额度)"),
|
||||
]
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def balance(self) -> str:
|
||||
"""显示余额(2 位小数)"""
|
||||
return format_display(self.balance_units)
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def available_balance(self) -> str:
|
||||
"""显示可用余额(2 位小数)"""
|
||||
return format_display(self.balance_units - self.frozen_units)
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def frozen_balance(self) -> str:
|
||||
"""显示冻结余额(2 位小数)"""
|
||||
return format_display(self.frozen_units)
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def total_recharged(self) -> str:
|
||||
"""显示累计充值(2 位小数)"""
|
||||
return format_display(self.total_recharged_units)
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def total_consumed(self) -> str:
|
||||
"""显示累计消费(2 位小数)"""
|
||||
return format_display(self.total_consumed_units)
|
||||
|
||||
|
||||
class BalanceSummaryResponse(BaseSchema):
|
||||
"""余额简要信息响应(用于嵌入用户信息)"""
|
||||
|
||||
balance: str = Field(description="当前余额")
|
||||
available_balance: str = Field(description="可用余额")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 交易记录 Schema
|
||||
# ============================================================
|
||||
|
||||
class TransactionResponse(BaseSchema):
|
||||
"""交易记录响应"""
|
||||
|
||||
id: str
|
||||
transaction_type: TransactionType
|
||||
status: TransactionStatus
|
||||
amount_units: Annotated[
|
||||
int,
|
||||
Field(description="交易金额(单位额度,正数收入,负数支出)"),
|
||||
]
|
||||
balance_before_units: Annotated[
|
||||
int,
|
||||
Field(description="交易前余额(单位额度)"),
|
||||
]
|
||||
balance_after_units: Annotated[
|
||||
int,
|
||||
Field(description="交易后余额(单位额度)"),
|
||||
]
|
||||
reference_type: str | None
|
||||
reference_id: str | None
|
||||
description: str | None
|
||||
created_at: datetime
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def amount(self) -> str:
|
||||
"""显示交易金额(带符号,2 位小数)"""
|
||||
return f"{self.amount_units / UNITS_PER_DISPLAY:+.2f}"
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def balance_before(self) -> str:
|
||||
"""显示交易前余额(2 位小数)"""
|
||||
return format_display(self.balance_before_units)
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def balance_after(self) -> str:
|
||||
"""显示交易后余额(2 位小数)"""
|
||||
return format_display(self.balance_after_units)
|
||||
|
||||
|
||||
class TransactionListResponse(PaginatedResponse[TransactionResponse]):
|
||||
"""交易记录列表响应"""
|
||||
pass
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 扣款请求 Schema
|
||||
# ============================================================
|
||||
|
||||
class DeductionRequest(BaseSchema):
|
||||
"""扣款请求"""
|
||||
|
||||
amount: Annotated[
|
||||
float,
|
||||
Field(
|
||||
gt=0,
|
||||
description="扣款金额(显示金额,如 1.00)",
|
||||
examples=[1.00, 0.50],
|
||||
),
|
||||
]
|
||||
reference_type: Annotated[
|
||||
str | None,
|
||||
Field(
|
||||
default=None,
|
||||
max_length=64,
|
||||
description="关联业务类型",
|
||||
examples=["api_call", "service_fee"],
|
||||
),
|
||||
]
|
||||
reference_id: Annotated[
|
||||
str | None,
|
||||
Field(
|
||||
default=None,
|
||||
max_length=64,
|
||||
description="关联业务 ID",
|
||||
),
|
||||
]
|
||||
description: Annotated[
|
||||
str | None,
|
||||
Field(
|
||||
default=None,
|
||||
max_length=255,
|
||||
description="交易描述",
|
||||
),
|
||||
]
|
||||
idempotency_key: Annotated[
|
||||
str | None,
|
||||
Field(
|
||||
default=None,
|
||||
max_length=64,
|
||||
description="幂等键(防止重复扣款)",
|
||||
),
|
||||
]
|
||||
|
||||
@field_validator("amount")
|
||||
@classmethod
|
||||
def validate_amount(cls, v: float) -> float:
|
||||
"""验证金额精度"""
|
||||
# 最小精度 0.001(即 1 单位额度)
|
||||
if round(v, 3) != v:
|
||||
raise ValueError("金额精度不能超过 3 位小数")
|
||||
return v
|
||||
|
||||
@property
|
||||
def amount_units(self) -> int:
|
||||
"""转换为单位额度"""
|
||||
return display_to_units(self.amount)
|
||||
|
||||
|
||||
class DeductionResponse(BaseSchema):
|
||||
"""扣款响应"""
|
||||
|
||||
transaction_id: str
|
||||
amount: str = Field(description="扣款金额")
|
||||
balance_before: str = Field(description="扣款前余额")
|
||||
balance_after: str = Field(description="扣款后余额")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 管理员操作 Schema
|
||||
# ============================================================
|
||||
|
||||
class AdminAdjustmentRequest(BaseSchema):
|
||||
"""管理员余额调整请求"""
|
||||
|
||||
user_id: Annotated[
|
||||
str,
|
||||
Field(description="目标用户 ID"),
|
||||
]
|
||||
amount: Annotated[
|
||||
float,
|
||||
Field(
|
||||
description="调整金额(正数增加,负数减少)",
|
||||
examples=[10.00, -5.00],
|
||||
),
|
||||
]
|
||||
reason: Annotated[
|
||||
str,
|
||||
Field(
|
||||
min_length=1,
|
||||
max_length=255,
|
||||
description="调整原因",
|
||||
),
|
||||
]
|
||||
|
||||
@field_validator("amount")
|
||||
@classmethod
|
||||
def validate_amount(cls, v: float) -> float:
|
||||
"""验证金额精度"""
|
||||
if round(v, 3) != v:
|
||||
raise ValueError("金额精度不能超过 3 位小数")
|
||||
if v == 0:
|
||||
raise ValueError("调整金额不能为 0")
|
||||
return v
|
||||
|
||||
@property
|
||||
def amount_units(self) -> int:
|
||||
"""转换为单位额度"""
|
||||
return display_to_units(self.amount)
|
||||
|
||||
|
||||
class AdminBalanceResponse(BaseSchema):
|
||||
"""管理员查看的余额信息(包含更多细节)"""
|
||||
|
||||
user_id: str
|
||||
username: str
|
||||
balance: str
|
||||
available_balance: str
|
||||
frozen_balance: str
|
||||
total_recharged: str
|
||||
total_consumed: str
|
||||
version: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
102
app/schemas/base.py
Normal file
102
app/schemas/base.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""
|
||||
基础 Schema 定义
|
||||
|
||||
定义通用的响应模式和基础配置。
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
DataT = TypeVar("DataT")
|
||||
|
||||
|
||||
class BaseSchema(BaseModel):
|
||||
"""基础 Schema 配置"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True, # 支持从 ORM 对象创建
|
||||
populate_by_name=True, # 支持字段别名
|
||||
str_strip_whitespace=True, # 自动去除字符串首尾空白
|
||||
)
|
||||
|
||||
|
||||
class TimestampMixin(BaseModel):
|
||||
"""时间戳混入"""
|
||||
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class APIResponse(BaseModel, Generic[DataT]):
|
||||
"""统一 API 响应格式"""
|
||||
|
||||
success: bool = True
|
||||
message: str = "操作成功"
|
||||
data: DataT | None = None
|
||||
|
||||
@classmethod
|
||||
def ok(cls, data: DataT | None = None, message: str = "操作成功") -> "APIResponse[DataT]":
|
||||
"""成功响应"""
|
||||
return cls(success=True, message=message, data=data)
|
||||
|
||||
@classmethod
|
||||
def error(cls, message: str = "操作失败", data: DataT | None = None) -> "APIResponse[DataT]":
|
||||
"""错误响应"""
|
||||
return cls(success=False, message=message, data=data)
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
"""错误响应"""
|
||||
|
||||
success: bool = False
|
||||
message: str
|
||||
code: str
|
||||
details: dict[str, Any] = {}
|
||||
|
||||
|
||||
class PaginationParams(BaseModel):
|
||||
"""分页参数"""
|
||||
|
||||
page: int = 1
|
||||
page_size: int = 20
|
||||
|
||||
@property
|
||||
def offset(self) -> int:
|
||||
"""计算偏移量"""
|
||||
return (self.page - 1) * self.page_size
|
||||
|
||||
@property
|
||||
def limit(self) -> int:
|
||||
"""获取限制数量"""
|
||||
return self.page_size
|
||||
|
||||
|
||||
class PaginatedResponse(BaseModel, Generic[DataT]):
|
||||
"""分页响应"""
|
||||
|
||||
items: list[DataT]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
total_pages: int
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
items: list[DataT],
|
||||
total: int,
|
||||
page: int,
|
||||
page_size: int,
|
||||
) -> "PaginatedResponse[DataT]":
|
||||
"""创建分页响应"""
|
||||
total_pages = (total + page_size - 1) // page_size if page_size > 0 else 0
|
||||
return cls(
|
||||
items=items,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
|
||||
85
app/schemas/oauth2.py
Normal file
85
app/schemas/oauth2.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""
|
||||
OAuth2 相关 Schema
|
||||
|
||||
定义 OAuth2 认证流程的数据结构。
|
||||
"""
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from app.schemas.base import BaseSchema
|
||||
|
||||
|
||||
class OAuth2AuthorizeResponse(BaseSchema):
|
||||
"""OAuth2 授权 URL 响应"""
|
||||
|
||||
authorize_url: str = Field(description="OAuth2 授权页面 URL")
|
||||
state: str = Field(description="防 CSRF 状态码")
|
||||
|
||||
|
||||
class OAuth2CallbackRequest(BaseSchema):
|
||||
"""OAuth2 回调请求"""
|
||||
|
||||
code: Annotated[
|
||||
str,
|
||||
Field(
|
||||
min_length=1,
|
||||
description="授权码",
|
||||
),
|
||||
]
|
||||
state: Annotated[
|
||||
str,
|
||||
Field(
|
||||
min_length=1,
|
||||
description="状态码(用于验证请求合法性)",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class OAuth2TokenData(BaseSchema):
|
||||
"""OAuth2 令牌数据(从 OAuth 提供商获取)"""
|
||||
|
||||
access_token: str
|
||||
token_type: str = "Bearer"
|
||||
expires_in: int | None = None
|
||||
refresh_token: str | None = None
|
||||
scope: str | None = None
|
||||
|
||||
|
||||
class OAuth2UserInfo(BaseSchema):
|
||||
"""
|
||||
OAuth2 用户信息(从 Linux.do 获取)
|
||||
|
||||
示例响应:
|
||||
{
|
||||
"id": 1,
|
||||
"username": "neo",
|
||||
"name": "Neo",
|
||||
"active": true,
|
||||
"trust_level": 4,
|
||||
"email": "u1@linux.do",
|
||||
"avatar_url": "https://linux.do/xxxx",
|
||||
"silenced": false
|
||||
}
|
||||
"""
|
||||
|
||||
id: int | str = Field(description="用户 ID")
|
||||
username: str = Field(description="用户名")
|
||||
name: str | None = Field(default=None, description="显示名称")
|
||||
email: str | None = Field(default=None, description="邮箱")
|
||||
avatar_url: str | None = Field(default=None, description="头像 URL")
|
||||
active: bool = Field(default=True, description="是否激活")
|
||||
trust_level: int | None = Field(default=None, description="信任等级")
|
||||
silenced: bool = Field(default=False, description="是否被禁言")
|
||||
|
||||
|
||||
class OAuth2LoginResponse(BaseSchema):
|
||||
"""OAuth2 登录响应"""
|
||||
|
||||
access_token: str = Field(description="JWT 访问令牌")
|
||||
refresh_token: str = Field(description="JWT 刷新令牌")
|
||||
token_type: str = Field(default="Bearer", description="令牌类型")
|
||||
expires_in: int = Field(description="访问令牌过期时间(秒)")
|
||||
is_new_user: bool = Field(description="是否为新注册用户")
|
||||
|
||||
346
app/schemas/redeem_code.py
Normal file
346
app/schemas/redeem_code.py
Normal file
@@ -0,0 +1,346 @@
|
||||
"""
|
||||
兑换码相关 Schema
|
||||
|
||||
定义兑换码数据的验证和序列化规则。
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
from app.schemas.base import BaseSchema, PaginatedResponse
|
||||
from app.schemas.balance import UNITS_PER_DISPLAY, display_to_units, format_display
|
||||
from app.models.redeem_code import RedeemCodeStatus
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 用户兑换 Schema
|
||||
# ============================================================
|
||||
|
||||
class RedeemRequest(BaseSchema):
|
||||
"""兑换请求"""
|
||||
|
||||
code: Annotated[
|
||||
str,
|
||||
Field(
|
||||
min_length=1,
|
||||
max_length=32,
|
||||
description="兑换码",
|
||||
examples=["ABCD-EFGH-JKLM-NPQR"],
|
||||
),
|
||||
]
|
||||
|
||||
@field_validator("code")
|
||||
@classmethod
|
||||
def normalize_code(cls, v: str) -> str:
|
||||
"""标准化兑换码格式"""
|
||||
# 移除空格,转大写
|
||||
return v.strip().upper().replace(" ", "")
|
||||
|
||||
|
||||
class RedeemResponse(BaseSchema):
|
||||
"""兑换响应"""
|
||||
|
||||
success: bool = True
|
||||
message: str = "兑换成功"
|
||||
face_value: str = Field(description="兑换金额")
|
||||
balance_before: str = Field(description="兑换前余额")
|
||||
balance_after: str = Field(description="兑换后余额")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 兑换码信息 Schema
|
||||
# ============================================================
|
||||
|
||||
class RedeemCodeResponse(BaseSchema):
|
||||
"""兑换码信息响应"""
|
||||
|
||||
id: str
|
||||
code: str
|
||||
face_value_units: int = Field(description="面值(单位额度)")
|
||||
status: RedeemCodeStatus
|
||||
max_uses: int
|
||||
used_count: int
|
||||
expires_at: datetime | None
|
||||
used_at: datetime | None
|
||||
created_at: datetime
|
||||
|
||||
@property
|
||||
def face_value(self) -> str:
|
||||
"""显示面值(2 位小数)"""
|
||||
return format_display(self.face_value_units)
|
||||
|
||||
@property
|
||||
def is_valid(self) -> bool:
|
||||
"""是否有效"""
|
||||
if self.status != RedeemCodeStatus.ACTIVE:
|
||||
return False
|
||||
if self.used_count >= self.max_uses:
|
||||
return False
|
||||
if self.expires_at and self.expires_at < datetime.now(self.expires_at.tzinfo):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class RedeemCodeDetailResponse(RedeemCodeResponse):
|
||||
"""兑换码详细信息响应(管理员用)"""
|
||||
|
||||
batch_id: str | None
|
||||
batch_name: str | None = None
|
||||
remark: str | None
|
||||
created_by: str | None
|
||||
used_by: str | None
|
||||
|
||||
|
||||
class RedeemCodeListResponse(PaginatedResponse[RedeemCodeDetailResponse]):
|
||||
"""兑换码列表响应"""
|
||||
pass
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 批次 Schema
|
||||
# ============================================================
|
||||
|
||||
class BatchCreateRequest(BaseSchema):
|
||||
"""创建兑换码批次请求"""
|
||||
|
||||
name: Annotated[
|
||||
str,
|
||||
Field(
|
||||
min_length=1,
|
||||
max_length=128,
|
||||
description="批次名称",
|
||||
examples=["2024年1月活动"],
|
||||
),
|
||||
]
|
||||
description: Annotated[
|
||||
str | None,
|
||||
Field(
|
||||
default=None,
|
||||
max_length=500,
|
||||
description="批次描述",
|
||||
),
|
||||
]
|
||||
face_value: Annotated[
|
||||
float,
|
||||
Field(
|
||||
gt=0,
|
||||
description="面值(显示金额)",
|
||||
examples=[10.00, 50.00],
|
||||
),
|
||||
]
|
||||
count: Annotated[
|
||||
int,
|
||||
Field(
|
||||
gt=0,
|
||||
le=10000,
|
||||
description="生成数量",
|
||||
examples=[100],
|
||||
),
|
||||
]
|
||||
max_uses: Annotated[
|
||||
int,
|
||||
Field(
|
||||
default=1,
|
||||
gt=0,
|
||||
le=100,
|
||||
description="每个兑换码最大使用次数",
|
||||
),
|
||||
]
|
||||
expires_at: Annotated[
|
||||
datetime | None,
|
||||
Field(
|
||||
default=None,
|
||||
description="过期时间",
|
||||
),
|
||||
]
|
||||
|
||||
@field_validator("face_value")
|
||||
@classmethod
|
||||
def validate_face_value(cls, v: float) -> float:
|
||||
"""验证面值精度"""
|
||||
if round(v, 3) != v:
|
||||
raise ValueError("面值精度不能超过 3 位小数")
|
||||
return v
|
||||
|
||||
@property
|
||||
def face_value_units(self) -> int:
|
||||
"""转换为单位额度"""
|
||||
return display_to_units(self.face_value)
|
||||
|
||||
|
||||
class BatchResponse(BaseSchema):
|
||||
"""批次信息响应"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
description: str | None
|
||||
face_value_units: int
|
||||
total_count: int
|
||||
used_count: int
|
||||
created_by: str | None
|
||||
created_at: datetime
|
||||
|
||||
@property
|
||||
def face_value(self) -> str:
|
||||
"""显示面值(2 位小数)"""
|
||||
return format_display(self.face_value_units)
|
||||
|
||||
@property
|
||||
def unused_count(self) -> int:
|
||||
"""未使用数量"""
|
||||
return self.total_count - self.used_count
|
||||
|
||||
|
||||
class BatchDetailResponse(BatchResponse):
|
||||
"""批次详细信息响应"""
|
||||
|
||||
codes: list[RedeemCodeResponse] = []
|
||||
|
||||
|
||||
class BatchListResponse(PaginatedResponse[BatchResponse]):
|
||||
"""批次列表响应"""
|
||||
pass
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 导入导出 Schema
|
||||
# ============================================================
|
||||
|
||||
class ImportCodeRequest(BaseSchema):
|
||||
"""导入兑换码请求"""
|
||||
|
||||
code: Annotated[
|
||||
str,
|
||||
Field(
|
||||
min_length=1,
|
||||
max_length=32,
|
||||
description="兑换码",
|
||||
),
|
||||
]
|
||||
face_value: Annotated[
|
||||
float,
|
||||
Field(
|
||||
gt=0,
|
||||
description="面值",
|
||||
),
|
||||
]
|
||||
max_uses: Annotated[
|
||||
int,
|
||||
Field(
|
||||
default=1,
|
||||
gt=0,
|
||||
),
|
||||
]
|
||||
expires_at: datetime | None = None
|
||||
remark: str | None = None
|
||||
|
||||
@field_validator("code")
|
||||
@classmethod
|
||||
def normalize_code(cls, v: str) -> str:
|
||||
"""标准化兑换码"""
|
||||
return v.strip().upper()
|
||||
|
||||
@property
|
||||
def face_value_units(self) -> int:
|
||||
"""转换为单位额度"""
|
||||
return display_to_units(self.face_value)
|
||||
|
||||
|
||||
class BulkImportRequest(BaseSchema):
|
||||
"""批量导入兑换码请求"""
|
||||
|
||||
codes: Annotated[
|
||||
list[ImportCodeRequest],
|
||||
Field(
|
||||
min_length=1,
|
||||
max_length=1000,
|
||||
description="兑换码列表",
|
||||
),
|
||||
]
|
||||
batch_name: Annotated[
|
||||
str | None,
|
||||
Field(
|
||||
default=None,
|
||||
max_length=128,
|
||||
description="批次名称(可选)",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class BulkImportResponse(BaseSchema):
|
||||
"""批量导入响应"""
|
||||
|
||||
success_count: int
|
||||
failed_count: int
|
||||
failed_codes: list[str] = []
|
||||
batch_id: str | None = None
|
||||
|
||||
|
||||
class ExportCodeItem(BaseSchema):
|
||||
"""导出兑换码条目"""
|
||||
|
||||
code: str
|
||||
face_value: str
|
||||
status: str
|
||||
max_uses: int
|
||||
used_count: int
|
||||
expires_at: str | None
|
||||
created_at: str
|
||||
used_at: str | None
|
||||
used_by: str | None
|
||||
|
||||
|
||||
class ExportResponse(BaseSchema):
|
||||
"""导出响应"""
|
||||
|
||||
total: int
|
||||
codes: list[ExportCodeItem]
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 使用日志 Schema
|
||||
# ============================================================
|
||||
|
||||
class UsageLogResponse(BaseSchema):
|
||||
"""兑换码使用日志响应"""
|
||||
|
||||
id: str
|
||||
redeem_code_id: str
|
||||
code_snapshot: str
|
||||
user_id: str
|
||||
username: str | None = None
|
||||
face_value: str
|
||||
ip_address: str | None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class UsageLogListResponse(PaginatedResponse[UsageLogResponse]):
|
||||
"""使用日志列表响应"""
|
||||
pass
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 查询参数 Schema
|
||||
# ============================================================
|
||||
|
||||
class RedeemCodeQueryParams(BaseSchema):
|
||||
"""兑换码查询参数"""
|
||||
|
||||
status: RedeemCodeStatus | None = None
|
||||
batch_id: str | None = None
|
||||
code: str | None = None
|
||||
created_after: datetime | None = None
|
||||
created_before: datetime | None = None
|
||||
|
||||
|
||||
class UsageLogQueryParams(BaseSchema):
|
||||
"""使用日志查询参数"""
|
||||
|
||||
redeem_code_id: str | None = None
|
||||
user_id: str | None = None
|
||||
code: str | None = None
|
||||
created_after: datetime | None = None
|
||||
created_before: datetime | None = None
|
||||
|
||||
156
app/schemas/user.py
Normal file
156
app/schemas/user.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""
|
||||
用户相关 Schema
|
||||
|
||||
定义用户数据的验证和序列化规则。
|
||||
"""
|
||||
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import EmailStr, Field, field_validator
|
||||
|
||||
from app.core.config import settings
|
||||
from app.schemas.base import BaseSchema
|
||||
|
||||
|
||||
# 用户名正则:字母开头,只允许字母、数字、下划线
|
||||
USERNAME_PATTERN = re.compile(r"^[a-zA-Z][a-zA-Z0-9_]*$")
|
||||
|
||||
|
||||
class UserBase(BaseSchema):
|
||||
"""用户基础字段"""
|
||||
|
||||
username: Annotated[
|
||||
str,
|
||||
Field(
|
||||
min_length=settings.username_min_length,
|
||||
max_length=settings.username_max_length,
|
||||
description="用户名(字母开头,只允许字母、数字、下划线)",
|
||||
examples=["john_doe"],
|
||||
),
|
||||
]
|
||||
email: Annotated[
|
||||
EmailStr | None,
|
||||
Field(
|
||||
default=None,
|
||||
description="邮箱地址",
|
||||
examples=["user@example.com"],
|
||||
),
|
||||
]
|
||||
nickname: Annotated[
|
||||
str | None,
|
||||
Field(
|
||||
default=None,
|
||||
max_length=64,
|
||||
description="昵称",
|
||||
examples=["John"],
|
||||
),
|
||||
]
|
||||
|
||||
@field_validator("username")
|
||||
@classmethod
|
||||
def validate_username(cls, v: str) -> str:
|
||||
"""验证用户名格式"""
|
||||
if not USERNAME_PATTERN.match(v):
|
||||
raise ValueError("用户名必须以字母开头,只能包含字母、数字和下划线")
|
||||
return v.lower() # 统一转小写
|
||||
|
||||
|
||||
class UserCreate(UserBase):
|
||||
"""用户注册请求"""
|
||||
|
||||
password: Annotated[
|
||||
str,
|
||||
Field(
|
||||
min_length=settings.password_min_length,
|
||||
max_length=settings.password_max_length,
|
||||
description="密码",
|
||||
examples=["SecurePass123"],
|
||||
),
|
||||
]
|
||||
|
||||
@field_validator("password")
|
||||
@classmethod
|
||||
def validate_password(cls, v: str) -> str:
|
||||
"""验证密码强度"""
|
||||
errors: list[str] = []
|
||||
|
||||
if settings.password_require_uppercase and not re.search(r"[A-Z]", v):
|
||||
errors.append("至少包含一个大写字母")
|
||||
|
||||
if settings.password_require_lowercase and not re.search(r"[a-z]", v):
|
||||
errors.append("至少包含一个小写字母")
|
||||
|
||||
if settings.password_require_digit and not re.search(r"\d", v):
|
||||
errors.append("至少包含一个数字")
|
||||
|
||||
if settings.password_require_special and not re.search(r"[!@#$%^&*(),.?\":{}|<>]", v):
|
||||
errors.append("至少包含一个特殊字符")
|
||||
|
||||
if errors:
|
||||
raise ValueError("密码强度不足:" + ";".join(errors))
|
||||
|
||||
return v
|
||||
|
||||
|
||||
class UserUpdate(BaseSchema):
|
||||
"""用户信息更新请求"""
|
||||
|
||||
nickname: Annotated[
|
||||
str | None,
|
||||
Field(
|
||||
default=None,
|
||||
max_length=64,
|
||||
description="昵称",
|
||||
),
|
||||
]
|
||||
email: Annotated[
|
||||
EmailStr | None,
|
||||
Field(
|
||||
default=None,
|
||||
description="邮箱地址",
|
||||
),
|
||||
]
|
||||
avatar_url: Annotated[
|
||||
str | None,
|
||||
Field(
|
||||
default=None,
|
||||
max_length=512,
|
||||
description="头像 URL",
|
||||
),
|
||||
]
|
||||
bio: Annotated[
|
||||
str | None,
|
||||
Field(
|
||||
default=None,
|
||||
max_length=500,
|
||||
description="个人简介",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class UserResponse(BaseSchema):
|
||||
"""用户信息响应"""
|
||||
|
||||
id: str
|
||||
username: str
|
||||
email: str | None
|
||||
nickname: str | None
|
||||
avatar_url: str | None
|
||||
bio: str | None
|
||||
is_active: bool
|
||||
created_at: datetime
|
||||
last_login_at: datetime | None
|
||||
|
||||
|
||||
class UserProfileResponse(BaseSchema):
|
||||
"""用户公开资料响应(不包含敏感信息)"""
|
||||
|
||||
id: str
|
||||
username: str
|
||||
nickname: str | None
|
||||
avatar_url: str | None
|
||||
bio: str | None
|
||||
created_at: datetime
|
||||
|
||||
16
app/services/__init__.py
Normal file
16
app/services/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""业务服务层"""
|
||||
|
||||
from app.services.auth import AuthService
|
||||
from app.services.oauth2 import OAuth2Service
|
||||
from app.services.user import UserService
|
||||
from app.services.balance import BalanceService
|
||||
from app.services.redeem_code import RedeemCodeService
|
||||
|
||||
__all__ = [
|
||||
"AuthService",
|
||||
"OAuth2Service",
|
||||
"UserService",
|
||||
"BalanceService",
|
||||
"RedeemCodeService",
|
||||
]
|
||||
|
||||
296
app/services/auth.py
Normal file
296
app/services/auth.py
Normal file
@@ -0,0 +1,296 @@
|
||||
"""
|
||||
认证服务
|
||||
|
||||
处理用户认证相关的业务逻辑。
|
||||
"""
|
||||
|
||||
from datetime import timedelta
|
||||
|
||||
import jwt
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.exceptions import (
|
||||
InvalidCredentialsError,
|
||||
PasswordValidationError,
|
||||
TokenError,
|
||||
TokenExpiredError,
|
||||
UserDisabledError,
|
||||
UserNotFoundError,
|
||||
)
|
||||
from app.core.security import (
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
decode_token,
|
||||
hash_password,
|
||||
password_needs_rehash,
|
||||
verify_password,
|
||||
)
|
||||
from app.models.user import User
|
||||
from app.repositories.user import UserRepository
|
||||
from app.schemas.auth import PasswordChangeRequest, TokenResponse
|
||||
from app.services.user import UserService
|
||||
|
||||
|
||||
class AuthService:
|
||||
"""认证服务"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
"""
|
||||
初始化认证服务
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
"""
|
||||
self.session = session
|
||||
self.user_repo = UserRepository(session)
|
||||
self.user_service = UserService(session)
|
||||
|
||||
async def authenticate(
|
||||
self,
|
||||
username: str,
|
||||
password: str,
|
||||
) -> User:
|
||||
"""
|
||||
验证用户凭证
|
||||
|
||||
Args:
|
||||
username: 用户名或邮箱
|
||||
password: 密码
|
||||
|
||||
Returns:
|
||||
验证成功的用户对象
|
||||
|
||||
Raises:
|
||||
InvalidCredentialsError: 凭证无效
|
||||
UserDisabledError: 用户被禁用
|
||||
"""
|
||||
# 查找用户(支持用户名或邮箱登录)
|
||||
user = await self.user_repo.get_by_username_or_email(username)
|
||||
|
||||
if not user:
|
||||
# 防止时序攻击:即使用户不存在也进行密码验证
|
||||
verify_password(password, "$argon2id$v=19$m=65536,t=3,p=4$dummy$dummy")
|
||||
raise InvalidCredentialsError()
|
||||
|
||||
# 验证密码
|
||||
if not verify_password(password, user.hashed_password):
|
||||
raise InvalidCredentialsError()
|
||||
|
||||
# 检查用户状态
|
||||
if not user.is_active:
|
||||
raise UserDisabledError()
|
||||
|
||||
# 检查是否需要重新哈希密码(参数升级)
|
||||
if password_needs_rehash(user.hashed_password):
|
||||
await self.user_repo.update(
|
||||
user,
|
||||
hashed_password=hash_password(password),
|
||||
)
|
||||
await self.user_repo.commit()
|
||||
|
||||
return user
|
||||
|
||||
async def login(
|
||||
self,
|
||||
username: str,
|
||||
password: str,
|
||||
) -> tuple[User, TokenResponse]:
|
||||
"""
|
||||
用户登录
|
||||
|
||||
Args:
|
||||
username: 用户名或邮箱
|
||||
password: 密码
|
||||
|
||||
Returns:
|
||||
(用户对象, 令牌响应)
|
||||
"""
|
||||
user = await self.authenticate(username, password)
|
||||
|
||||
# 更新最后登录时间
|
||||
await self.user_service.update_last_login(user)
|
||||
|
||||
# 生成令牌
|
||||
tokens = self._create_tokens(user)
|
||||
|
||||
return user, tokens
|
||||
|
||||
def _create_tokens(self, user: User) -> TokenResponse:
|
||||
"""
|
||||
为用户创建访问令牌和刷新令牌
|
||||
|
||||
Args:
|
||||
user: 用户对象
|
||||
|
||||
Returns:
|
||||
令牌响应
|
||||
"""
|
||||
access_token = create_access_token(
|
||||
subject=user.id,
|
||||
extra_claims={
|
||||
"username": user.username,
|
||||
"is_superuser": user.is_superuser,
|
||||
},
|
||||
)
|
||||
|
||||
refresh_token = create_refresh_token(subject=user.id)
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
token_type="Bearer",
|
||||
expires_in=settings.access_token_expire_minutes * 60,
|
||||
)
|
||||
|
||||
async def refresh_tokens(self, refresh_token: str) -> TokenResponse:
|
||||
"""
|
||||
刷新访问令牌
|
||||
|
||||
Args:
|
||||
refresh_token: 刷新令牌
|
||||
|
||||
Returns:
|
||||
新的令牌响应
|
||||
|
||||
Raises:
|
||||
TokenError: 令牌无效
|
||||
TokenExpiredError: 令牌已过期
|
||||
"""
|
||||
try:
|
||||
payload = decode_token(refresh_token)
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise TokenExpiredError()
|
||||
except jwt.InvalidTokenError:
|
||||
raise TokenError()
|
||||
|
||||
# 验证令牌类型
|
||||
if payload.get("type") != "refresh":
|
||||
raise TokenError("无效的令牌类型")
|
||||
|
||||
# 获取用户
|
||||
user_id = payload.get("sub")
|
||||
if not user_id:
|
||||
raise TokenError()
|
||||
|
||||
user = await self.user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
raise UserNotFoundError(user_id)
|
||||
|
||||
if not user.is_active:
|
||||
raise UserDisabledError()
|
||||
|
||||
# 生成新令牌
|
||||
return self._create_tokens(user)
|
||||
|
||||
async def change_password(
|
||||
self,
|
||||
user_id: str,
|
||||
password_data: PasswordChangeRequest,
|
||||
) -> None:
|
||||
"""
|
||||
修改用户密码
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
password_data: 密码修改数据
|
||||
|
||||
Raises:
|
||||
UserNotFoundError: 用户不存在
|
||||
InvalidCredentialsError: 当前密码错误
|
||||
PasswordValidationError: 新密码不符合要求
|
||||
"""
|
||||
user = await self.user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
raise UserNotFoundError(user_id)
|
||||
|
||||
# 验证当前密码
|
||||
if not verify_password(password_data.current_password, user.hashed_password):
|
||||
raise InvalidCredentialsError("当前密码错误")
|
||||
|
||||
# 验证新密码强度
|
||||
self._validate_password_strength(password_data.new_password)
|
||||
|
||||
# 更新密码
|
||||
await self.user_repo.update(
|
||||
user,
|
||||
hashed_password=hash_password(password_data.new_password),
|
||||
)
|
||||
await self.user_repo.commit()
|
||||
|
||||
def _validate_password_strength(self, password: str) -> None:
|
||||
"""
|
||||
验证密码强度
|
||||
|
||||
Args:
|
||||
password: 密码
|
||||
|
||||
Raises:
|
||||
PasswordValidationError: 密码不符合要求
|
||||
"""
|
||||
import re
|
||||
|
||||
errors: list[str] = []
|
||||
|
||||
if len(password) < settings.password_min_length:
|
||||
errors.append(f"密码长度不能少于 {settings.password_min_length} 位")
|
||||
|
||||
if len(password) > settings.password_max_length:
|
||||
errors.append(f"密码长度不能超过 {settings.password_max_length} 位")
|
||||
|
||||
if settings.password_require_uppercase and not re.search(r"[A-Z]", password):
|
||||
errors.append("至少包含一个大写字母")
|
||||
|
||||
if settings.password_require_lowercase and not re.search(r"[a-z]", password):
|
||||
errors.append("至少包含一个小写字母")
|
||||
|
||||
if settings.password_require_digit and not re.search(r"\d", password):
|
||||
errors.append("至少包含一个数字")
|
||||
|
||||
if settings.password_require_special and not re.search(r"[!@#$%^&*(),.?\":{}|<>]", password):
|
||||
errors.append("至少包含一个特殊字符")
|
||||
|
||||
if errors:
|
||||
raise PasswordValidationError(";".join(errors))
|
||||
|
||||
async def get_current_user(self, token: str) -> User:
|
||||
"""
|
||||
从令牌获取当前用户
|
||||
|
||||
Args:
|
||||
token: 访问令牌
|
||||
|
||||
Returns:
|
||||
用户对象
|
||||
|
||||
Raises:
|
||||
TokenError: 令牌无效
|
||||
TokenExpiredError: 令牌已过期
|
||||
UserNotFoundError: 用户不存在
|
||||
UserDisabledError: 用户被禁用
|
||||
"""
|
||||
try:
|
||||
payload = decode_token(token)
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise TokenExpiredError()
|
||||
except jwt.InvalidTokenError:
|
||||
raise TokenError()
|
||||
|
||||
# 验证令牌类型
|
||||
if payload.get("type") != "access":
|
||||
raise TokenError("无效的令牌类型")
|
||||
|
||||
user_id = payload.get("sub")
|
||||
if not user_id:
|
||||
raise TokenError()
|
||||
|
||||
user = await self.user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
raise UserNotFoundError(user_id)
|
||||
|
||||
if not user.is_active:
|
||||
raise UserDisabledError()
|
||||
|
||||
return user
|
||||
|
||||
934
app/services/balance.py
Normal file
934
app/services/balance.py
Normal file
@@ -0,0 +1,934 @@
|
||||
"""
|
||||
余额服务
|
||||
|
||||
处理余额相关的业务逻辑。
|
||||
|
||||
设计说明:
|
||||
- 所有金额操作使用整数单位(units),避免浮点精度问题
|
||||
- 扣款操作使用行级锁(悲观锁)确保原子性
|
||||
- 充值操作使用乐观锁,配合重试机制
|
||||
- 每笔操作都记录交易流水
|
||||
|
||||
预扣款流程(内部方法,用于耗时付费操作):
|
||||
1. pre_authorize() - 预扣款,冻结金额,快速释放锁,返回交易ID
|
||||
2. 执行耗时的付费操作(使用交易ID追踪)
|
||||
3. confirm() 或 cancel() - 根据操作结果确认或取消
|
||||
|
||||
推荐使用上下文管理器 deduction_context() 自动处理确认/取消。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, AsyncIterator, Callable, Awaitable, TypeVar
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.exceptions import (
|
||||
AppException,
|
||||
ResourceNotFoundError,
|
||||
ValidationError,
|
||||
)
|
||||
from app.models.balance import (
|
||||
UserBalance,
|
||||
BalanceTransaction,
|
||||
TransactionType,
|
||||
TransactionStatus,
|
||||
)
|
||||
from app.repositories.balance import BalanceRepository, TransactionRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class InsufficientBalanceError(AppException):
|
||||
"""余额不足"""
|
||||
|
||||
def __init__(self, required: int, available: int):
|
||||
super().__init__(
|
||||
f"余额不足,需要 {required / 1000:.2f},当前可用 {available / 1000:.2f}",
|
||||
"INSUFFICIENT_BALANCE",
|
||||
{"required_units": required, "available_units": available},
|
||||
)
|
||||
|
||||
|
||||
class DuplicateTransactionError(AppException):
|
||||
"""重复交易"""
|
||||
|
||||
def __init__(self, idempotency_key: str):
|
||||
super().__init__(
|
||||
"该交易已处理",
|
||||
"DUPLICATE_TRANSACTION",
|
||||
{"idempotency_key": idempotency_key},
|
||||
)
|
||||
|
||||
|
||||
class ConcurrencyError(AppException):
|
||||
"""并发冲突"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
"操作冲突,请重试",
|
||||
"CONCURRENCY_ERROR",
|
||||
)
|
||||
|
||||
|
||||
class TransactionNotFoundError(AppException):
|
||||
"""交易不存在"""
|
||||
|
||||
def __init__(self, transaction_id: str):
|
||||
super().__init__(
|
||||
"交易记录不存在",
|
||||
"TRANSACTION_NOT_FOUND",
|
||||
{"transaction_id": transaction_id},
|
||||
)
|
||||
|
||||
|
||||
class TransactionStateError(AppException):
|
||||
"""交易状态错误"""
|
||||
|
||||
def __init__(self, transaction_id: str, current_status: str, expected_status: str = "pending"):
|
||||
super().__init__(
|
||||
f"交易状态无效:当前 {current_status},预期 {expected_status}",
|
||||
"TRANSACTION_STATE_ERROR",
|
||||
{
|
||||
"transaction_id": transaction_id,
|
||||
"current_status": current_status,
|
||||
"expected_status": expected_status,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PreAuthResult:
|
||||
"""
|
||||
预授权结果
|
||||
|
||||
包含交易ID和相关信息,用于后续确认或取消操作。
|
||||
"""
|
||||
|
||||
transaction_id: str
|
||||
"""交易ID,用于后续 confirm/cancel 操作"""
|
||||
|
||||
user_id: str
|
||||
"""用户ID"""
|
||||
|
||||
amount_units: int
|
||||
"""预扣款金额(单位额度)"""
|
||||
|
||||
frozen_at: datetime
|
||||
"""冻结时间"""
|
||||
|
||||
@property
|
||||
def amount_display(self) -> str:
|
||||
"""显示金额(2位小数)"""
|
||||
return f"{self.amount_units / 1000:.2f}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeductionResult:
|
||||
"""
|
||||
扣款结果
|
||||
|
||||
包含扣款操作的完整信息。
|
||||
"""
|
||||
|
||||
transaction_id: str
|
||||
"""交易ID"""
|
||||
|
||||
status: TransactionStatus
|
||||
"""交易状态"""
|
||||
|
||||
amount_units: int
|
||||
"""实际扣款金额(单位额度)"""
|
||||
|
||||
balance_before: int
|
||||
"""扣款前余额"""
|
||||
|
||||
balance_after: int
|
||||
"""扣款后余额"""
|
||||
|
||||
@property
|
||||
def success(self) -> bool:
|
||||
"""是否扣款成功"""
|
||||
return self.status == TransactionStatus.COMPLETED
|
||||
|
||||
@property
|
||||
def amount_display(self) -> str:
|
||||
"""显示金额"""
|
||||
return f"{abs(self.amount_units) / 1000:.2f}"
|
||||
|
||||
@property
|
||||
def balance_before_display(self) -> str:
|
||||
"""显示扣款前余额"""
|
||||
return f"{self.balance_before / 1000:.2f}"
|
||||
|
||||
@property
|
||||
def balance_after_display(self) -> str:
|
||||
"""显示扣款后余额"""
|
||||
return f"{self.balance_after / 1000:.2f}"
|
||||
|
||||
|
||||
class BalanceService:
|
||||
"""余额服务"""
|
||||
|
||||
# 乐观锁最大重试次数
|
||||
MAX_RETRIES = 3
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
"""
|
||||
初始化余额服务
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
"""
|
||||
self.session = session
|
||||
self.balance_repo = BalanceRepository(session)
|
||||
self.transaction_repo = TransactionRepository(session)
|
||||
|
||||
# ============================================================
|
||||
# 余额查询
|
||||
# ============================================================
|
||||
|
||||
async def get_balance(self, user_id: str) -> UserBalance:
|
||||
"""
|
||||
获取用户余额
|
||||
|
||||
如果用户没有余额账户,自动创建一个。
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
|
||||
Returns:
|
||||
余额账户
|
||||
"""
|
||||
balance = await self.balance_repo.get_or_create(user_id)
|
||||
await self.balance_repo.commit()
|
||||
return balance
|
||||
|
||||
async def get_balance_detail(self, user_id: str) -> dict[str, Any]:
|
||||
"""
|
||||
获取用户余额详情
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
|
||||
Returns:
|
||||
余额详情字典
|
||||
"""
|
||||
balance = await self.get_balance(user_id)
|
||||
return {
|
||||
"user_id": balance.user_id,
|
||||
"balance_units": balance.balance,
|
||||
"frozen_units": balance.frozen_balance,
|
||||
"available_units": balance.available_balance,
|
||||
"total_recharged_units": balance.total_recharged,
|
||||
"total_consumed_units": balance.total_consumed,
|
||||
}
|
||||
|
||||
async def get_transactions(
|
||||
self,
|
||||
user_id: str,
|
||||
*,
|
||||
offset: int = 0,
|
||||
limit: int = 20,
|
||||
transaction_type: TransactionType | None = None,
|
||||
) -> tuple[list[BalanceTransaction], int]:
|
||||
"""
|
||||
获取用户交易记录
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
offset: 偏移量
|
||||
limit: 限制数量
|
||||
transaction_type: 交易类型过滤
|
||||
|
||||
Returns:
|
||||
(交易记录列表, 总数)
|
||||
"""
|
||||
transactions = await self.transaction_repo.get_by_user_id(
|
||||
user_id,
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
transaction_type=transaction_type,
|
||||
)
|
||||
total = await self.transaction_repo.count_by_user_id(
|
||||
user_id,
|
||||
transaction_type=transaction_type,
|
||||
)
|
||||
return transactions, total
|
||||
|
||||
# ============================================================
|
||||
# 扣款操作(使用行级锁 - 悲观锁)
|
||||
# ============================================================
|
||||
|
||||
async def deduct(
|
||||
self,
|
||||
user_id: str,
|
||||
amount_units: int,
|
||||
*,
|
||||
reference_type: str | None = None,
|
||||
reference_id: str | None = None,
|
||||
description: str | None = None,
|
||||
idempotency_key: str | None = None,
|
||||
) -> BalanceTransaction:
|
||||
"""
|
||||
扣款
|
||||
|
||||
使用行级锁确保原子性,防止并发扣款导致余额变负。
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
amount_units: 扣款金额(单位额度,正数)
|
||||
reference_type: 关联业务类型
|
||||
reference_id: 关联业务 ID
|
||||
description: 交易描述
|
||||
idempotency_key: 幂等键
|
||||
|
||||
Returns:
|
||||
交易记录
|
||||
|
||||
Raises:
|
||||
InsufficientBalanceError: 余额不足
|
||||
DuplicateTransactionError: 重复交易
|
||||
"""
|
||||
if amount_units <= 0:
|
||||
raise ValidationError("扣款金额必须大于 0")
|
||||
|
||||
# 检查幂等性
|
||||
if idempotency_key:
|
||||
existing = await self.transaction_repo.get_by_idempotency_key(
|
||||
idempotency_key
|
||||
)
|
||||
if existing:
|
||||
raise DuplicateTransactionError(idempotency_key)
|
||||
|
||||
# 获取余额账户并加锁
|
||||
balance = await self.balance_repo.get_or_create_for_update(user_id)
|
||||
|
||||
# 检查可用余额
|
||||
if balance.available_balance < amount_units:
|
||||
raise InsufficientBalanceError(amount_units, balance.available_balance)
|
||||
|
||||
# 记录扣款前余额
|
||||
balance_before = balance.balance
|
||||
|
||||
# 执行扣款
|
||||
balance.balance -= amount_units
|
||||
balance.total_consumed += amount_units
|
||||
balance.version += 1
|
||||
|
||||
# 创建交易记录
|
||||
transaction = await self.transaction_repo.create(
|
||||
user_id=user_id,
|
||||
balance_account_id=balance.id,
|
||||
transaction_type=TransactionType.DEDUCTION,
|
||||
status=TransactionStatus.COMPLETED,
|
||||
amount=-amount_units, # 负数表示支出
|
||||
balance_before=balance_before,
|
||||
balance_after=balance.balance,
|
||||
reference_type=reference_type,
|
||||
reference_id=reference_id,
|
||||
description=description,
|
||||
idempotency_key=idempotency_key,
|
||||
)
|
||||
|
||||
await self.balance_repo.commit()
|
||||
|
||||
logger.info(
|
||||
f"用户 {user_id} 扣款成功: {amount_units} 单位, "
|
||||
f"余额 {balance_before} -> {balance.balance}"
|
||||
)
|
||||
|
||||
return transaction
|
||||
|
||||
# ============================================================
|
||||
# 预扣款流程(内部方法,用于耗时付费操作)
|
||||
# ============================================================
|
||||
|
||||
async def pre_authorize(
|
||||
self,
|
||||
user_id: str,
|
||||
amount_units: int,
|
||||
*,
|
||||
reference_type: str | None = None,
|
||||
reference_id: str | None = None,
|
||||
description: str | None = None,
|
||||
) -> PreAuthResult:
|
||||
"""
|
||||
预授权扣款(内部方法)
|
||||
|
||||
冻结指定金额,快速释放数据库锁,返回交易ID供后续操作使用。
|
||||
此方法设计用于耗时的付费操作场景。
|
||||
|
||||
使用流程:
|
||||
1. 调用 pre_authorize() 获取 PreAuthResult
|
||||
2. 执行可能失败的耗时操作
|
||||
3. 根据操作结果调用 confirm() 或 cancel()
|
||||
|
||||
推荐使用 deduction_context() 上下文管理器自动处理。
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
amount_units: 扣款金额(单位额度,正数)
|
||||
reference_type: 关联业务类型(如 api_call, service)
|
||||
reference_id: 关联业务 ID
|
||||
description: 交易描述
|
||||
|
||||
Returns:
|
||||
PreAuthResult: 预授权结果,包含交易ID
|
||||
|
||||
Raises:
|
||||
InsufficientBalanceError: 余额不足
|
||||
ValidationError: 参数无效
|
||||
"""
|
||||
if amount_units <= 0:
|
||||
raise ValidationError("预扣款金额必须大于 0")
|
||||
|
||||
# 获取余额账户并加锁(短暂持有)
|
||||
balance = await self.balance_repo.get_or_create_for_update(user_id)
|
||||
|
||||
# 检查可用余额
|
||||
if balance.available_balance < amount_units:
|
||||
raise InsufficientBalanceError(amount_units, balance.available_balance)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# 执行冻结
|
||||
balance.frozen_balance += amount_units
|
||||
balance.version += 1
|
||||
|
||||
# 创建待处理交易记录
|
||||
transaction = await self.transaction_repo.create(
|
||||
user_id=user_id,
|
||||
balance_account_id=balance.id,
|
||||
transaction_type=TransactionType.DEDUCTION,
|
||||
status=TransactionStatus.PENDING,
|
||||
amount=-amount_units,
|
||||
balance_before=balance.balance,
|
||||
balance_after=balance.balance, # 尚未实际扣款
|
||||
reference_type=reference_type,
|
||||
reference_id=reference_id,
|
||||
description=description,
|
||||
remark=f"预授权冻结: {amount_units} 单位",
|
||||
)
|
||||
|
||||
# 快速提交释放锁
|
||||
await self.balance_repo.commit()
|
||||
|
||||
logger.info(
|
||||
f"用户 {user_id} 预授权成功: {amount_units} 单位, "
|
||||
f"交易ID: {transaction.id}"
|
||||
)
|
||||
|
||||
return PreAuthResult(
|
||||
transaction_id=transaction.id,
|
||||
user_id=user_id,
|
||||
amount_units=amount_units,
|
||||
frozen_at=now,
|
||||
)
|
||||
|
||||
async def confirm(
|
||||
self,
|
||||
transaction_id: str,
|
||||
*,
|
||||
actual_amount_units: int | None = None,
|
||||
) -> DeductionResult:
|
||||
"""
|
||||
确认预授权扣款(内部方法)
|
||||
|
||||
将预冻结的金额实际扣除。支持部分扣款。
|
||||
|
||||
Args:
|
||||
transaction_id: 预授权交易 ID
|
||||
actual_amount_units: 实际扣款金额(可选,用于部分扣款,默认全额)
|
||||
|
||||
Returns:
|
||||
DeductionResult: 扣款结果
|
||||
|
||||
Raises:
|
||||
TransactionNotFoundError: 交易不存在
|
||||
TransactionStateError: 交易状态不是 PENDING
|
||||
ValidationError: 参数无效
|
||||
"""
|
||||
transaction = await self.transaction_repo.get_by_id(transaction_id)
|
||||
if not transaction:
|
||||
raise TransactionNotFoundError(transaction_id)
|
||||
|
||||
if transaction.status != TransactionStatus.PENDING:
|
||||
raise TransactionStateError(
|
||||
transaction_id,
|
||||
transaction.status.value,
|
||||
)
|
||||
|
||||
# 获取余额账户并加锁
|
||||
balance = await self.balance_repo.get_by_user_id_for_update(transaction.user_id)
|
||||
if not balance:
|
||||
raise ResourceNotFoundError("余额账户不存在")
|
||||
|
||||
frozen_amount = abs(transaction.amount)
|
||||
|
||||
# 确定实际扣款金额
|
||||
if actual_amount_units is not None:
|
||||
if actual_amount_units <= 0:
|
||||
raise ValidationError("实际扣款金额必须大于 0")
|
||||
if actual_amount_units > frozen_amount:
|
||||
raise ValidationError(
|
||||
f"实际扣款金额 ({actual_amount_units}) 不能超过预授权金额 ({frozen_amount})"
|
||||
)
|
||||
deduct_amount = actual_amount_units
|
||||
else:
|
||||
deduct_amount = frozen_amount
|
||||
|
||||
# 检查冻结金额
|
||||
if balance.frozen_balance < frozen_amount:
|
||||
raise ValidationError("冻结金额不足,可能已被其他操作修改")
|
||||
|
||||
balance_before = balance.balance
|
||||
|
||||
# 执行扣款:解冻全部,扣除实际金额
|
||||
balance.frozen_balance -= frozen_amount
|
||||
balance.balance -= deduct_amount
|
||||
balance.total_consumed += deduct_amount
|
||||
balance.version += 1
|
||||
|
||||
# 更新交易记录
|
||||
transaction.status = TransactionStatus.COMPLETED
|
||||
transaction.amount = -deduct_amount # 更新为实际扣款金额
|
||||
transaction.balance_after = balance.balance
|
||||
|
||||
await self.balance_repo.commit()
|
||||
|
||||
logger.info(
|
||||
f"用户 {transaction.user_id} 确认扣款: {deduct_amount} 单位, "
|
||||
f"余额 {balance_before} -> {balance.balance}"
|
||||
)
|
||||
|
||||
return DeductionResult(
|
||||
transaction_id=transaction.id,
|
||||
status=TransactionStatus.COMPLETED,
|
||||
amount_units=deduct_amount,
|
||||
balance_before=balance_before,
|
||||
balance_after=balance.balance,
|
||||
)
|
||||
|
||||
async def cancel(
|
||||
self,
|
||||
transaction_id: str,
|
||||
*,
|
||||
reason: str | None = None,
|
||||
) -> DeductionResult:
|
||||
"""
|
||||
取消预授权扣款(内部方法)
|
||||
|
||||
解冻预授权的金额,退回用户可用余额。
|
||||
|
||||
Args:
|
||||
transaction_id: 预授权交易 ID
|
||||
reason: 取消原因(可选,记录在日志中)
|
||||
|
||||
Returns:
|
||||
DeductionResult: 取消结果
|
||||
|
||||
Raises:
|
||||
TransactionNotFoundError: 交易不存在
|
||||
TransactionStateError: 交易状态不是 PENDING
|
||||
"""
|
||||
transaction = await self.transaction_repo.get_by_id(transaction_id)
|
||||
if not transaction:
|
||||
raise TransactionNotFoundError(transaction_id)
|
||||
|
||||
if transaction.status != TransactionStatus.PENDING:
|
||||
raise TransactionStateError(
|
||||
transaction_id,
|
||||
transaction.status.value,
|
||||
)
|
||||
|
||||
# 获取余额账户并加锁
|
||||
balance = await self.balance_repo.get_by_user_id_for_update(transaction.user_id)
|
||||
if not balance:
|
||||
raise ResourceNotFoundError("余额账户不存在")
|
||||
|
||||
frozen_amount = abs(transaction.amount)
|
||||
|
||||
# 解冻
|
||||
balance.frozen_balance -= frozen_amount
|
||||
balance.version += 1
|
||||
|
||||
# 更新交易记录
|
||||
transaction.status = TransactionStatus.CANCELLED
|
||||
if reason:
|
||||
transaction.remark = f"{transaction.remark or ''}; 取消原因: {reason}"
|
||||
|
||||
await self.balance_repo.commit()
|
||||
|
||||
logger.info(
|
||||
f"用户 {transaction.user_id} 取消预授权: {frozen_amount} 单位"
|
||||
+ (f", 原因: {reason}" if reason else "")
|
||||
)
|
||||
|
||||
return DeductionResult(
|
||||
transaction_id=transaction.id,
|
||||
status=TransactionStatus.CANCELLED,
|
||||
amount_units=0, # 实际未扣款
|
||||
balance_before=balance.balance,
|
||||
balance_after=balance.balance,
|
||||
)
|
||||
|
||||
@asynccontextmanager
|
||||
async def deduction_context(
|
||||
self,
|
||||
user_id: str,
|
||||
amount_units: int,
|
||||
*,
|
||||
reference_type: str | None = None,
|
||||
reference_id: str | None = None,
|
||||
description: str | None = None,
|
||||
auto_cancel_on_error: bool = True,
|
||||
) -> AsyncIterator[PreAuthResult]:
|
||||
"""
|
||||
扣款上下文管理器(推荐使用)
|
||||
|
||||
提供简便的预扣款流程,自动处理确认和取消。
|
||||
异常时自动取消预授权,退回冻结金额。
|
||||
|
||||
用法示例:
|
||||
```python
|
||||
async with balance_service.deduction_context(
|
||||
user_id,
|
||||
1000, # 扣款金额(单位额度)
|
||||
reference_type="api_call",
|
||||
description="API调用费用",
|
||||
) as pre_auth:
|
||||
# pre_auth.transaction_id 可用于追踪
|
||||
# 执行可能失败的耗时操作
|
||||
result = await call_external_api()
|
||||
if not result.success:
|
||||
raise Exception("API 调用失败")
|
||||
# 成功退出时自动确认扣款
|
||||
# 异常退出时自动取消预授权(如果 auto_cancel_on_error=True)
|
||||
```
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
amount_units: 扣款金额(单位额度)
|
||||
reference_type: 关联业务类型
|
||||
reference_id: 关联业务 ID
|
||||
description: 交易描述
|
||||
auto_cancel_on_error: 异常时是否自动取消(默认 True)
|
||||
|
||||
Yields:
|
||||
PreAuthResult: 预授权结果,包含交易ID
|
||||
|
||||
Raises:
|
||||
InsufficientBalanceError: 余额不足
|
||||
"""
|
||||
# 第一阶段:预授权
|
||||
pre_auth = await self.pre_authorize(
|
||||
user_id,
|
||||
amount_units,
|
||||
reference_type=reference_type,
|
||||
reference_id=reference_id,
|
||||
description=description,
|
||||
)
|
||||
|
||||
try:
|
||||
yield pre_auth
|
||||
# 正常退出:确认扣款
|
||||
await self.confirm(pre_auth.transaction_id)
|
||||
except Exception as e:
|
||||
# 异常退出:取消预授权
|
||||
if auto_cancel_on_error:
|
||||
try:
|
||||
await self.cancel(
|
||||
pre_auth.transaction_id,
|
||||
reason=f"操作失败: {str(e)[:200]}",
|
||||
)
|
||||
except Exception as cancel_error:
|
||||
logger.error(
|
||||
f"取消预授权失败: {pre_auth.transaction_id}, "
|
||||
f"错误: {cancel_error}"
|
||||
)
|
||||
raise
|
||||
|
||||
async def execute_with_deduction(
|
||||
self,
|
||||
user_id: str,
|
||||
amount_units: int,
|
||||
operation: Callable[[PreAuthResult], Awaitable[T]],
|
||||
*,
|
||||
reference_type: str | None = None,
|
||||
reference_id: str | None = None,
|
||||
description: str | None = None,
|
||||
) -> tuple[DeductionResult, T]:
|
||||
"""
|
||||
执行带扣款的操作(函数式接口)
|
||||
|
||||
预扣款后执行指定操作,根据操作结果自动确认或取消。
|
||||
|
||||
用法示例:
|
||||
```python
|
||||
async def call_api(pre_auth: PreAuthResult):
|
||||
return await external_api.call(
|
||||
transaction_id=pre_auth.transaction_id,
|
||||
amount=pre_auth.amount_display,
|
||||
)
|
||||
|
||||
deduction_result, api_result = await balance_service.execute_with_deduction(
|
||||
user_id,
|
||||
1000,
|
||||
call_api,
|
||||
reference_type="api_call",
|
||||
)
|
||||
```
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
amount_units: 扣款金额(单位额度)
|
||||
operation: 要执行的异步操作,接收 PreAuthResult 参数
|
||||
reference_type: 关联业务类型
|
||||
reference_id: 关联业务 ID
|
||||
description: 交易描述
|
||||
|
||||
Returns:
|
||||
(DeductionResult, operation返回值): 扣款结果和操作结果
|
||||
|
||||
Raises:
|
||||
InsufficientBalanceError: 余额不足
|
||||
Exception: 操作抛出的异常(预授权会自动取消)
|
||||
"""
|
||||
pre_auth = await self.pre_authorize(
|
||||
user_id,
|
||||
amount_units,
|
||||
reference_type=reference_type,
|
||||
reference_id=reference_id,
|
||||
description=description,
|
||||
)
|
||||
|
||||
try:
|
||||
# 执行操作
|
||||
result = await operation(pre_auth)
|
||||
# 成功:确认扣款
|
||||
deduction_result = await self.confirm(pre_auth.transaction_id)
|
||||
return deduction_result, result
|
||||
except Exception as e:
|
||||
# 失败:取消预授权
|
||||
try:
|
||||
await self.cancel(
|
||||
pre_auth.transaction_id,
|
||||
reason=f"操作失败: {str(e)[:200]}",
|
||||
)
|
||||
except Exception as cancel_error:
|
||||
logger.error(
|
||||
f"取消预授权失败: {pre_auth.transaction_id}, "
|
||||
f"错误: {cancel_error}"
|
||||
)
|
||||
raise
|
||||
|
||||
# ============================================================
|
||||
# 兼容方法(保留旧接口)
|
||||
# ============================================================
|
||||
|
||||
async def deduct_with_freeze(
|
||||
self,
|
||||
user_id: str,
|
||||
amount_units: int,
|
||||
*,
|
||||
reference_type: str | None = None,
|
||||
reference_id: str | None = None,
|
||||
description: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
冻结并预扣款(兼容方法,推荐使用 pre_authorize)
|
||||
|
||||
Returns:
|
||||
交易ID
|
||||
"""
|
||||
result = await self.pre_authorize(
|
||||
user_id,
|
||||
amount_units,
|
||||
reference_type=reference_type,
|
||||
reference_id=reference_id,
|
||||
description=description,
|
||||
)
|
||||
return result.transaction_id
|
||||
|
||||
async def confirm_frozen_deduction(self, transaction_id: str) -> BalanceTransaction:
|
||||
"""
|
||||
确认冻结扣款(兼容方法,推荐使用 confirm)
|
||||
"""
|
||||
await self.confirm(transaction_id)
|
||||
transaction = await self.transaction_repo.get_by_id(transaction_id)
|
||||
return transaction # type: ignore
|
||||
|
||||
async def cancel_frozen_deduction(self, transaction_id: str) -> BalanceTransaction:
|
||||
"""
|
||||
取消冻结扣款(兼容方法,推荐使用 cancel)
|
||||
"""
|
||||
await self.cancel(transaction_id)
|
||||
transaction = await self.transaction_repo.get_by_id(transaction_id)
|
||||
return transaction # type: ignore
|
||||
|
||||
# ============================================================
|
||||
# 充值操作(使用乐观锁 + 重试)
|
||||
# ============================================================
|
||||
|
||||
async def recharge(
|
||||
self,
|
||||
user_id: str,
|
||||
amount_units: int,
|
||||
*,
|
||||
reference_type: str | None = None,
|
||||
reference_id: str | None = None,
|
||||
description: str | None = None,
|
||||
idempotency_key: str | None = None,
|
||||
) -> BalanceTransaction:
|
||||
"""
|
||||
充值
|
||||
|
||||
使用乐观锁,配合重试机制处理并发冲突。
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
amount_units: 充值金额(单位额度,正数)
|
||||
reference_type: 关联业务类型
|
||||
reference_id: 关联业务 ID
|
||||
description: 交易描述
|
||||
idempotency_key: 幂等键
|
||||
|
||||
Returns:
|
||||
交易记录
|
||||
|
||||
Raises:
|
||||
DuplicateTransactionError: 重复交易
|
||||
ConcurrencyError: 并发冲突(重试失败)
|
||||
"""
|
||||
if amount_units <= 0:
|
||||
raise ValidationError("充值金额必须大于 0")
|
||||
|
||||
# 检查幂等性
|
||||
if idempotency_key:
|
||||
existing = await self.transaction_repo.get_by_idempotency_key(
|
||||
idempotency_key
|
||||
)
|
||||
if existing:
|
||||
raise DuplicateTransactionError(idempotency_key)
|
||||
|
||||
# 乐观锁重试
|
||||
for attempt in range(self.MAX_RETRIES):
|
||||
balance = await self.balance_repo.get_or_create(user_id)
|
||||
balance_before = balance.balance
|
||||
|
||||
# 尝试更新余额
|
||||
success = await self.balance_repo.update_balance_optimistic(
|
||||
balance,
|
||||
amount_units,
|
||||
is_recharge=True,
|
||||
)
|
||||
|
||||
if success:
|
||||
# 创建交易记录
|
||||
transaction = await self.transaction_repo.create(
|
||||
user_id=user_id,
|
||||
balance_account_id=balance.id,
|
||||
transaction_type=TransactionType.RECHARGE,
|
||||
status=TransactionStatus.COMPLETED,
|
||||
amount=amount_units, # 正数表示收入
|
||||
balance_before=balance_before,
|
||||
balance_after=balance.balance,
|
||||
reference_type=reference_type,
|
||||
reference_id=reference_id,
|
||||
description=description,
|
||||
idempotency_key=idempotency_key,
|
||||
)
|
||||
|
||||
await self.balance_repo.commit()
|
||||
|
||||
logger.info(
|
||||
f"用户 {user_id} 充值成功: {amount_units} 单位, "
|
||||
f"余额 {balance_before} -> {balance.balance}"
|
||||
)
|
||||
|
||||
return transaction
|
||||
|
||||
# 冲突,重试
|
||||
logger.warning(
|
||||
f"用户 {user_id} 充值冲突,重试 {attempt + 1}/{self.MAX_RETRIES}"
|
||||
)
|
||||
await self.balance_repo.rollback()
|
||||
|
||||
# 重试失败
|
||||
raise ConcurrencyError()
|
||||
|
||||
# ============================================================
|
||||
# 管理员操作
|
||||
# ============================================================
|
||||
|
||||
async def admin_adjust(
|
||||
self,
|
||||
user_id: str,
|
||||
amount_units: int,
|
||||
*,
|
||||
operator_id: str,
|
||||
reason: str,
|
||||
) -> BalanceTransaction:
|
||||
"""
|
||||
管理员调整余额
|
||||
|
||||
Args:
|
||||
user_id: 目标用户 ID
|
||||
amount_units: 调整金额(正数增加,负数减少)
|
||||
operator_id: 操作人 ID
|
||||
reason: 调整原因
|
||||
|
||||
Returns:
|
||||
交易记录
|
||||
|
||||
Raises:
|
||||
InsufficientBalanceError: 减少金额时余额不足
|
||||
"""
|
||||
if amount_units == 0:
|
||||
raise ValidationError("调整金额不能为 0")
|
||||
|
||||
# 获取余额账户并加锁
|
||||
balance = await self.balance_repo.get_or_create_for_update(user_id)
|
||||
|
||||
# 减少时检查余额
|
||||
if amount_units < 0 and balance.available_balance < abs(amount_units):
|
||||
raise InsufficientBalanceError(
|
||||
abs(amount_units), balance.available_balance
|
||||
)
|
||||
|
||||
balance_before = balance.balance
|
||||
|
||||
# 执行调整
|
||||
balance.balance += amount_units
|
||||
if amount_units > 0:
|
||||
balance.total_recharged += amount_units
|
||||
balance.version += 1
|
||||
|
||||
# 创建交易记录
|
||||
transaction = await self.transaction_repo.create(
|
||||
user_id=user_id,
|
||||
balance_account_id=balance.id,
|
||||
transaction_type=TransactionType.ADJUSTMENT,
|
||||
status=TransactionStatus.COMPLETED,
|
||||
amount=amount_units,
|
||||
balance_before=balance_before,
|
||||
balance_after=balance.balance,
|
||||
description=reason,
|
||||
operator_id=operator_id,
|
||||
remark=f"管理员调整: {reason}",
|
||||
)
|
||||
|
||||
await self.balance_repo.commit()
|
||||
|
||||
logger.info(
|
||||
f"管理员 {operator_id} 调整用户 {user_id} 余额: {amount_units} 单位, "
|
||||
f"原因: {reason}"
|
||||
)
|
||||
|
||||
return transaction
|
||||
|
||||
395
app/services/oauth2.py
Normal file
395
app/services/oauth2.py
Normal file
@@ -0,0 +1,395 @@
|
||||
"""
|
||||
OAuth2 认证服务
|
||||
|
||||
处理 OAuth2 认证流程,支持主备端点自动切换。
|
||||
当首选端点不可达时,自动回退到备用端点。
|
||||
"""
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import httpx
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.exceptions import (
|
||||
AuthenticationError,
|
||||
ResourceConflictError,
|
||||
)
|
||||
from app.core.security import create_access_token, create_refresh_token
|
||||
from app.models.user import User
|
||||
from app.repositories.user import UserRepository
|
||||
from app.schemas.auth import TokenResponse
|
||||
from app.schemas.oauth2 import OAuth2TokenData, OAuth2UserInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# OAuth2 提供商标识
|
||||
OAUTH_PROVIDER_LINUXDO = "linuxdo"
|
||||
|
||||
|
||||
class OAuth2EndpointError(AuthenticationError):
|
||||
"""OAuth2 端点错误"""
|
||||
|
||||
def __init__(self, message: str = "OAuth2 服务不可用"):
|
||||
super().__init__(message, "OAUTH2_ENDPOINT_ERROR")
|
||||
|
||||
|
||||
class OAuth2StateError(AuthenticationError):
|
||||
"""OAuth2 状态验证错误"""
|
||||
|
||||
def __init__(self, message: str = "无效的状态码"):
|
||||
super().__init__(message, "OAUTH2_STATE_ERROR")
|
||||
|
||||
|
||||
class OAuth2Service:
|
||||
"""
|
||||
OAuth2 认证服务
|
||||
|
||||
特性:
|
||||
- 支持主备端点自动切换
|
||||
- 首选端点请求失败时自动回退到备用端点
|
||||
- 状态码验证防止 CSRF 攻击
|
||||
"""
|
||||
|
||||
# 存储状态码(生产环境应使用 Redis)
|
||||
_state_store: dict[str, bool] = {}
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
"""
|
||||
初始化 OAuth2 服务
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
"""
|
||||
self.session = session
|
||||
self.user_repo = UserRepository(session)
|
||||
|
||||
# 端点配置
|
||||
self._endpoints = {
|
||||
"authorize": {
|
||||
"primary": settings.oauth2_authorize_endpoint,
|
||||
"reserve": settings.oauth2_authorize_endpoint_reserve,
|
||||
},
|
||||
"token": {
|
||||
"primary": settings.oauth2_token_endpoint,
|
||||
"reserve": settings.oauth2_token_endpoint_reserve,
|
||||
},
|
||||
"userinfo": {
|
||||
"primary": settings.oauth2_user_info_endpoint,
|
||||
"reserve": settings.oauth2_user_info_endpoint_reserve,
|
||||
},
|
||||
}
|
||||
|
||||
self._timeout = settings.oauth2_request_timeout
|
||||
|
||||
def generate_authorize_url(self, redirect_uri: str) -> tuple[str, str]:
|
||||
"""
|
||||
生成 OAuth2 授权 URL
|
||||
|
||||
Args:
|
||||
redirect_uri: 回调 URL
|
||||
|
||||
Returns:
|
||||
(授权 URL, 状态码)
|
||||
"""
|
||||
state = secrets.token_urlsafe(32)
|
||||
self._state_store[state] = True # 存储状态码
|
||||
|
||||
params = {
|
||||
"client_id": settings.oauth2_client_id,
|
||||
"redirect_uri": redirect_uri,
|
||||
"response_type": "code",
|
||||
"state": state,
|
||||
"scope": "read", # 根据实际需要调整
|
||||
}
|
||||
|
||||
# 使用首选授权端点
|
||||
authorize_url = f"{self._endpoints['authorize']['primary']}?{urlencode(params)}"
|
||||
|
||||
return authorize_url, state
|
||||
|
||||
def validate_state(self, state: str) -> bool:
|
||||
"""
|
||||
验证状态码(防 CSRF)
|
||||
|
||||
Args:
|
||||
state: 状态码
|
||||
|
||||
Returns:
|
||||
是否有效
|
||||
"""
|
||||
if state in self._state_store:
|
||||
del self._state_store[state] # 使用后立即删除
|
||||
return True
|
||||
return False
|
||||
|
||||
async def _request_with_fallback(
|
||||
self,
|
||||
endpoint_type: str,
|
||||
method: str,
|
||||
**kwargs,
|
||||
) -> httpx.Response:
|
||||
"""
|
||||
带回退的 HTTP 请求
|
||||
|
||||
首先尝试首选端点,失败后自动切换到备用端点。
|
||||
|
||||
Args:
|
||||
endpoint_type: 端点类型(token/userinfo)
|
||||
method: HTTP 方法
|
||||
**kwargs: 请求参数
|
||||
|
||||
Returns:
|
||||
响应对象
|
||||
|
||||
Raises:
|
||||
OAuth2EndpointError: 所有端点都不可用
|
||||
"""
|
||||
endpoints = self._endpoints[endpoint_type]
|
||||
last_error: Exception | None = None
|
||||
|
||||
for endpoint_name, url in [("primary", endpoints["primary"]), ("reserve", endpoints["reserve"])]:
|
||||
try:
|
||||
logger.debug(f"尝试 OAuth2 {endpoint_type} 端点 ({endpoint_name}): {url}")
|
||||
|
||||
async with httpx.AsyncClient(timeout=self._timeout) as client:
|
||||
if method.upper() == "POST":
|
||||
response = await client.post(url, **kwargs)
|
||||
else:
|
||||
response = await client.get(url, **kwargs)
|
||||
|
||||
# 检查 HTTP 状态
|
||||
if response.status_code >= 500:
|
||||
logger.warning(
|
||||
f"OAuth2 {endpoint_type} 端点 ({endpoint_name}) "
|
||||
f"返回服务器错误: {response.status_code}"
|
||||
)
|
||||
continue
|
||||
|
||||
logger.info(f"OAuth2 {endpoint_type} 请求成功 ({endpoint_name})")
|
||||
return response
|
||||
|
||||
except httpx.TimeoutException as e:
|
||||
logger.warning(f"OAuth2 {endpoint_type} 端点 ({endpoint_name}) 超时: {e}")
|
||||
last_error = e
|
||||
except httpx.ConnectError as e:
|
||||
logger.warning(f"OAuth2 {endpoint_type} 端点 ({endpoint_name}) 连接失败: {e}")
|
||||
last_error = e
|
||||
except Exception as e:
|
||||
logger.error(f"OAuth2 {endpoint_type} 端点 ({endpoint_name}) 请求异常: {e}")
|
||||
last_error = e
|
||||
|
||||
# 所有端点都失败
|
||||
error_msg = f"OAuth2 {endpoint_type} 服务不可用"
|
||||
if last_error:
|
||||
error_msg += f": {last_error}"
|
||||
raise OAuth2EndpointError(error_msg)
|
||||
|
||||
async def exchange_code_for_token(
|
||||
self,
|
||||
code: str,
|
||||
redirect_uri: str,
|
||||
) -> OAuth2TokenData:
|
||||
"""
|
||||
用授权码换取访问令牌
|
||||
|
||||
Args:
|
||||
code: 授权码
|
||||
redirect_uri: 回调 URL(必须与授权时一致)
|
||||
|
||||
Returns:
|
||||
OAuth2 令牌数据
|
||||
|
||||
Raises:
|
||||
OAuth2EndpointError: 端点不可用
|
||||
AuthenticationError: 换取令牌失败
|
||||
"""
|
||||
data = {
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": settings.oauth2_client_id,
|
||||
"client_secret": settings.oauth2_client_secret,
|
||||
"code": code,
|
||||
"redirect_uri": redirect_uri,
|
||||
}
|
||||
|
||||
response = await self._request_with_fallback(
|
||||
"token",
|
||||
"POST",
|
||||
data=data,
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"OAuth2 token 响应错误: {response.status_code} - {response.text}")
|
||||
raise AuthenticationError(
|
||||
f"获取访问令牌失败: {response.status_code}",
|
||||
"OAUTH2_TOKEN_ERROR",
|
||||
)
|
||||
|
||||
token_data = response.json()
|
||||
return OAuth2TokenData(**token_data)
|
||||
|
||||
async def get_user_info(self, access_token: str) -> OAuth2UserInfo:
|
||||
"""
|
||||
获取 OAuth2 用户信息
|
||||
|
||||
Args:
|
||||
access_token: OAuth2 访问令牌
|
||||
|
||||
Returns:
|
||||
用户信息
|
||||
|
||||
Raises:
|
||||
OAuth2EndpointError: 端点不可用
|
||||
AuthenticationError: 获取用户信息失败
|
||||
"""
|
||||
response = await self._request_with_fallback(
|
||||
"userinfo",
|
||||
"GET",
|
||||
headers={
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Accept": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"OAuth2 userinfo 响应错误: {response.status_code} - {response.text}")
|
||||
raise AuthenticationError(
|
||||
f"获取用户信息失败: {response.status_code}",
|
||||
"OAUTH2_USERINFO_ERROR",
|
||||
)
|
||||
|
||||
user_data = response.json()
|
||||
return OAuth2UserInfo(**user_data)
|
||||
|
||||
async def authenticate(
|
||||
self,
|
||||
code: str,
|
||||
state: str,
|
||||
redirect_uri: str,
|
||||
) -> tuple[User, TokenResponse, bool]:
|
||||
"""
|
||||
完整的 OAuth2 认证流程
|
||||
|
||||
1. 验证状态码
|
||||
2. 用授权码换取令牌
|
||||
3. 获取用户信息
|
||||
4. 创建或更新用户
|
||||
5. 生成 JWT 令牌
|
||||
|
||||
Args:
|
||||
code: 授权码
|
||||
state: 状态码
|
||||
redirect_uri: 回调 URL
|
||||
|
||||
Returns:
|
||||
(用户对象, JWT 令牌响应, 是否新用户)
|
||||
|
||||
Raises:
|
||||
OAuth2StateError: 状态码无效
|
||||
OAuth2EndpointError: OAuth2 服务不可用
|
||||
AuthenticationError: 认证失败
|
||||
"""
|
||||
# 1. 验证状态码
|
||||
if not self.validate_state(state):
|
||||
raise OAuth2StateError()
|
||||
|
||||
# 2. 换取令牌
|
||||
oauth_token = await self.exchange_code_for_token(code, redirect_uri)
|
||||
|
||||
# 3. 获取用户信息
|
||||
oauth_user = await self.get_user_info(oauth_token.access_token)
|
||||
|
||||
# 4. 查找或创建用户
|
||||
user, is_new_user = await self._get_or_create_user(oauth_user)
|
||||
|
||||
# 5. 生成 JWT 令牌
|
||||
tokens = self._create_tokens(user)
|
||||
|
||||
return user, tokens, is_new_user
|
||||
|
||||
async def _get_or_create_user(
|
||||
self,
|
||||
oauth_user: OAuth2UserInfo,
|
||||
) -> tuple[User, bool]:
|
||||
"""
|
||||
根据 OAuth2 用户信息获取或创建本地用户
|
||||
|
||||
Args:
|
||||
oauth_user: OAuth2 用户信息
|
||||
|
||||
Returns:
|
||||
(用户对象, 是否新创建)
|
||||
"""
|
||||
oauth_user_id = str(oauth_user.id)
|
||||
|
||||
# 先通过 OAuth ID 查找
|
||||
user = await self.user_repo.get_by_oauth(
|
||||
provider=OAUTH_PROVIDER_LINUXDO,
|
||||
oauth_user_id=oauth_user_id,
|
||||
)
|
||||
|
||||
if user:
|
||||
# 更新用户信息(头像等可能变化)
|
||||
await self.user_repo.update(
|
||||
user,
|
||||
nickname=oauth_user.name or oauth_user.username,
|
||||
avatar_url=oauth_user.avatar_url,
|
||||
)
|
||||
await self.user_repo.commit()
|
||||
return user, False
|
||||
|
||||
# 检查用户名是否已存在
|
||||
username = oauth_user.username.lower()
|
||||
existing_user = await self.user_repo.get_by_username(username)
|
||||
if existing_user:
|
||||
# 用户名冲突,添加后缀
|
||||
username = f"{username}_{oauth_user_id[:8]}"
|
||||
|
||||
# 创建新用户
|
||||
user = await self.user_repo.create(
|
||||
username=username,
|
||||
email=oauth_user.email,
|
||||
nickname=oauth_user.name or oauth_user.username,
|
||||
avatar_url=oauth_user.avatar_url,
|
||||
oauth_provider=OAUTH_PROVIDER_LINUXDO,
|
||||
oauth_user_id=oauth_user_id,
|
||||
hashed_password=None, # OAuth2 用户无密码
|
||||
is_active=oauth_user.active,
|
||||
)
|
||||
|
||||
await self.user_repo.commit()
|
||||
logger.info(f"创建 OAuth2 用户: {user.username} (provider={OAUTH_PROVIDER_LINUXDO})")
|
||||
|
||||
return user, True
|
||||
|
||||
def _create_tokens(self, user: User) -> TokenResponse:
|
||||
"""
|
||||
为用户创建 JWT 令牌
|
||||
|
||||
Args:
|
||||
user: 用户对象
|
||||
|
||||
Returns:
|
||||
令牌响应
|
||||
"""
|
||||
access_token = create_access_token(
|
||||
subject=user.id,
|
||||
extra_claims={
|
||||
"username": user.username,
|
||||
"is_superuser": user.is_superuser,
|
||||
"oauth_provider": user.oauth_provider,
|
||||
},
|
||||
)
|
||||
|
||||
refresh_token = create_refresh_token(subject=user.id)
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
token_type="Bearer",
|
||||
expires_in=settings.access_token_expire_minutes * 60,
|
||||
)
|
||||
|
||||
570
app/services/redeem_code.py
Normal file
570
app/services/redeem_code.py
Normal file
@@ -0,0 +1,570 @@
|
||||
"""
|
||||
兑换码服务
|
||||
|
||||
处理兑换码相关的业务逻辑。
|
||||
|
||||
设计说明:
|
||||
- 兑换操作使用行级锁确保原子性
|
||||
- 支持批量生成和导入导出
|
||||
- 记录完整的使用日志
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.exceptions import (
|
||||
AppException,
|
||||
ResourceNotFoundError,
|
||||
ValidationError,
|
||||
)
|
||||
from app.models.redeem_code import (
|
||||
RedeemCode,
|
||||
RedeemCodeBatch,
|
||||
RedeemCodeUsageLog,
|
||||
RedeemCodeStatus,
|
||||
generate_redeem_code,
|
||||
)
|
||||
from app.models.balance import TransactionType
|
||||
from app.repositories.redeem_code import (
|
||||
RedeemCodeRepository,
|
||||
RedeemCodeBatchRepository,
|
||||
RedeemCodeUsageLogRepository,
|
||||
)
|
||||
from app.services.balance import BalanceService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RedeemCodeNotFoundError(AppException):
|
||||
"""兑换码不存在"""
|
||||
|
||||
def __init__(self, code: str):
|
||||
super().__init__(
|
||||
"兑换码不存在",
|
||||
"REDEEM_CODE_NOT_FOUND",
|
||||
{"code": code},
|
||||
)
|
||||
|
||||
|
||||
class RedeemCodeInvalidError(AppException):
|
||||
"""兑换码无效"""
|
||||
|
||||
def __init__(self, code: str, reason: str):
|
||||
super().__init__(
|
||||
f"兑换码无效: {reason}",
|
||||
"REDEEM_CODE_INVALID",
|
||||
{"code": code, "reason": reason},
|
||||
)
|
||||
|
||||
|
||||
class RedeemCodeExpiredError(AppException):
|
||||
"""兑换码已过期"""
|
||||
|
||||
def __init__(self, code: str):
|
||||
super().__init__(
|
||||
"兑换码已过期",
|
||||
"REDEEM_CODE_EXPIRED",
|
||||
{"code": code},
|
||||
)
|
||||
|
||||
|
||||
class RedeemCodeUsedError(AppException):
|
||||
"""兑换码已使用"""
|
||||
|
||||
def __init__(self, code: str):
|
||||
super().__init__(
|
||||
"兑换码已使用",
|
||||
"REDEEM_CODE_USED",
|
||||
{"code": code},
|
||||
)
|
||||
|
||||
|
||||
class RedeemCodeDisabledError(AppException):
|
||||
"""兑换码已禁用"""
|
||||
|
||||
def __init__(self, code: str):
|
||||
super().__init__(
|
||||
"兑换码已禁用",
|
||||
"REDEEM_CODE_DISABLED",
|
||||
{"code": code},
|
||||
)
|
||||
|
||||
|
||||
class RedeemCodeService:
|
||||
"""兑换码服务"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
"""
|
||||
初始化兑换码服务
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
"""
|
||||
self.session = session
|
||||
self.code_repo = RedeemCodeRepository(session)
|
||||
self.batch_repo = RedeemCodeBatchRepository(session)
|
||||
self.log_repo = RedeemCodeUsageLogRepository(session)
|
||||
self.balance_service = BalanceService(session)
|
||||
|
||||
# ============================================================
|
||||
# 用户兑换
|
||||
# ============================================================
|
||||
|
||||
async def redeem(
|
||||
self,
|
||||
user_id: str,
|
||||
code: str,
|
||||
*,
|
||||
ip_address: str | None = None,
|
||||
user_agent: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
用户兑换余额
|
||||
|
||||
使用行级锁确保原子性,防止并发兑换。
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
code: 兑换码
|
||||
ip_address: 客户端 IP
|
||||
user_agent: User Agent
|
||||
|
||||
Returns:
|
||||
兑换结果
|
||||
|
||||
Raises:
|
||||
RedeemCodeNotFoundError: 兑换码不存在
|
||||
RedeemCodeInvalidError: 兑换码无效
|
||||
"""
|
||||
# 标准化兑换码
|
||||
normalized_code = code.strip().upper().replace(" ", "")
|
||||
|
||||
# 获取兑换码并加锁
|
||||
redeem_code = await self.code_repo.get_by_code_for_update(normalized_code)
|
||||
|
||||
if not redeem_code:
|
||||
raise RedeemCodeNotFoundError(normalized_code)
|
||||
|
||||
# 验证兑换码状态
|
||||
self._validate_redeem_code(redeem_code)
|
||||
|
||||
# 获取用户当前余额
|
||||
balance = await self.balance_service.get_balance(user_id)
|
||||
balance_before = balance.balance
|
||||
|
||||
# 执行充值
|
||||
transaction = await self.balance_service.recharge(
|
||||
user_id,
|
||||
redeem_code.face_value,
|
||||
reference_type="redeem_code",
|
||||
reference_id=redeem_code.id,
|
||||
description=f"兑换码充值: {redeem_code.code}",
|
||||
)
|
||||
|
||||
# 标记兑换码已使用
|
||||
await self.code_repo.mark_as_used(redeem_code, user_id)
|
||||
|
||||
# 更新批次统计
|
||||
if redeem_code.batch_id:
|
||||
await self.batch_repo.increment_used_count(redeem_code.batch_id)
|
||||
|
||||
# 记录使用日志
|
||||
await self.log_repo.create(
|
||||
redeem_code_id=redeem_code.id,
|
||||
user_id=user_id,
|
||||
transaction_id=transaction.id,
|
||||
code_snapshot=redeem_code.code,
|
||||
face_value=redeem_code.face_value,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
|
||||
await self.code_repo.commit()
|
||||
|
||||
logger.info(
|
||||
f"用户 {user_id} 兑换成功: {redeem_code.code}, "
|
||||
f"面值 {redeem_code.face_value} 单位"
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "兑换成功",
|
||||
"face_value": f"{redeem_code.face_value / 1000:.2f}",
|
||||
"balance_before": f"{balance_before / 1000:.2f}",
|
||||
"balance_after": f"{transaction.balance_after / 1000:.2f}",
|
||||
}
|
||||
|
||||
def _validate_redeem_code(self, code: RedeemCode) -> None:
|
||||
"""验证兑换码有效性"""
|
||||
if code.status == RedeemCodeStatus.DISABLED:
|
||||
raise RedeemCodeDisabledError(code.code)
|
||||
|
||||
if code.status == RedeemCodeStatus.USED or code.used_count >= code.max_uses:
|
||||
raise RedeemCodeUsedError(code.code)
|
||||
|
||||
if code.expires_at and code.expires_at < datetime.now(timezone.utc):
|
||||
raise RedeemCodeExpiredError(code.code)
|
||||
|
||||
# ============================================================
|
||||
# 管理员:批量生成
|
||||
# ============================================================
|
||||
|
||||
async def create_batch(
|
||||
self,
|
||||
name: str,
|
||||
face_value_units: int,
|
||||
count: int,
|
||||
*,
|
||||
created_by: str,
|
||||
description: str | None = None,
|
||||
max_uses: int = 1,
|
||||
expires_at: datetime | None = None,
|
||||
) -> RedeemCodeBatch:
|
||||
"""
|
||||
创建兑换码批次
|
||||
|
||||
批量生成指定数量的兑换码。
|
||||
|
||||
Args:
|
||||
name: 批次名称
|
||||
face_value_units: 面值(单位额度)
|
||||
count: 生成数量
|
||||
created_by: 创建者 ID
|
||||
description: 批次描述
|
||||
max_uses: 每个兑换码最大使用次数
|
||||
expires_at: 过期时间
|
||||
|
||||
Returns:
|
||||
创建的批次
|
||||
"""
|
||||
if face_value_units <= 0:
|
||||
raise ValidationError("面值必须大于 0")
|
||||
if count <= 0 or count > 10000:
|
||||
raise ValidationError("数量必须在 1-10000 之间")
|
||||
|
||||
# 创建批次
|
||||
batch = await self.batch_repo.create(
|
||||
name=name,
|
||||
description=description,
|
||||
face_value=face_value_units,
|
||||
total_count=count,
|
||||
created_by=created_by,
|
||||
)
|
||||
|
||||
# 批量生成兑换码
|
||||
codes_data = []
|
||||
generated_codes = set()
|
||||
|
||||
while len(codes_data) < count:
|
||||
new_code = generate_redeem_code()
|
||||
if new_code not in generated_codes:
|
||||
generated_codes.add(new_code)
|
||||
codes_data.append({
|
||||
"code": new_code,
|
||||
"batch_id": batch.id,
|
||||
"face_value": face_value_units,
|
||||
"max_uses": max_uses,
|
||||
"expires_at": expires_at,
|
||||
"created_by": created_by,
|
||||
})
|
||||
|
||||
await self.code_repo.bulk_create(codes_data)
|
||||
await self.batch_repo.commit()
|
||||
|
||||
logger.info(
|
||||
f"管理员 {created_by} 创建批次 '{name}': "
|
||||
f"{count} 个兑换码, 面值 {face_value_units} 单位"
|
||||
)
|
||||
|
||||
return batch
|
||||
|
||||
# ============================================================
|
||||
# 管理员:导入
|
||||
# ============================================================
|
||||
|
||||
async def import_codes(
|
||||
self,
|
||||
codes: list[dict[str, Any]],
|
||||
*,
|
||||
created_by: str,
|
||||
batch_name: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
导入兑换码
|
||||
|
||||
Args:
|
||||
codes: 兑换码数据列表
|
||||
created_by: 创建者 ID
|
||||
batch_name: 批次名称(可选)
|
||||
|
||||
Returns:
|
||||
导入结果
|
||||
"""
|
||||
batch_id = None
|
||||
|
||||
# 创建批次(如果指定)
|
||||
if batch_name:
|
||||
# 计算总面值用于批次记录
|
||||
total_face_value = sum(c.get("face_value_units", 0) for c in codes)
|
||||
batch = await self.batch_repo.create(
|
||||
name=batch_name,
|
||||
description="导入批次",
|
||||
face_value=total_face_value // len(codes) if codes else 0,
|
||||
total_count=len(codes),
|
||||
created_by=created_by,
|
||||
)
|
||||
batch_id = batch.id
|
||||
|
||||
success_count = 0
|
||||
failed_codes = []
|
||||
|
||||
for code_data in codes:
|
||||
try:
|
||||
# 检查兑换码是否已存在
|
||||
existing = await self.code_repo.get_by_code(code_data["code"])
|
||||
if existing:
|
||||
failed_codes.append(code_data["code"])
|
||||
continue
|
||||
|
||||
# 创建兑换码
|
||||
await self.code_repo.create(
|
||||
code=code_data["code"].strip().upper(),
|
||||
batch_id=batch_id,
|
||||
face_value=code_data["face_value_units"],
|
||||
max_uses=code_data.get("max_uses", 1),
|
||||
expires_at=code_data.get("expires_at"),
|
||||
remark=code_data.get("remark"),
|
||||
created_by=created_by,
|
||||
)
|
||||
success_count += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"导入兑换码失败: {code_data.get('code')}, {e}")
|
||||
failed_codes.append(code_data.get("code", "unknown"))
|
||||
|
||||
await self.code_repo.commit()
|
||||
|
||||
logger.info(
|
||||
f"管理员 {created_by} 导入兑换码: "
|
||||
f"成功 {success_count}, 失败 {len(failed_codes)}"
|
||||
)
|
||||
|
||||
return {
|
||||
"success_count": success_count,
|
||||
"failed_count": len(failed_codes),
|
||||
"failed_codes": failed_codes,
|
||||
"batch_id": batch_id,
|
||||
}
|
||||
|
||||
# ============================================================
|
||||
# 管理员:导出
|
||||
# ============================================================
|
||||
|
||||
async def export_codes(
|
||||
self,
|
||||
*,
|
||||
batch_id: str | None = None,
|
||||
status: RedeemCodeStatus | None = None,
|
||||
limit: int = 10000,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
导出兑换码
|
||||
|
||||
Args:
|
||||
batch_id: 批次 ID 过滤
|
||||
status: 状态过滤
|
||||
limit: 最大导出数量
|
||||
|
||||
Returns:
|
||||
兑换码数据列表
|
||||
"""
|
||||
codes = await self.code_repo.get_all_with_filters(
|
||||
batch_id=batch_id,
|
||||
status=status,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
result = []
|
||||
for code in codes:
|
||||
result.append({
|
||||
"code": code.code,
|
||||
"face_value": f"{code.face_value / 1000:.2f}",
|
||||
"status": code.status.value,
|
||||
"max_uses": code.max_uses,
|
||||
"used_count": code.used_count,
|
||||
"expires_at": code.expires_at.isoformat() if code.expires_at else None,
|
||||
"created_at": code.created_at.isoformat(),
|
||||
"used_at": code.used_at.isoformat() if code.used_at else None,
|
||||
"used_by": code.used_by,
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
# ============================================================
|
||||
# 管理员:查询
|
||||
# ============================================================
|
||||
|
||||
async def get_codes(
|
||||
self,
|
||||
*,
|
||||
offset: int = 0,
|
||||
limit: int = 20,
|
||||
status: RedeemCodeStatus | None = None,
|
||||
batch_id: str | None = None,
|
||||
code_like: str | None = None,
|
||||
created_after: datetime | None = None,
|
||||
created_before: datetime | None = None,
|
||||
) -> tuple[list[RedeemCode], int]:
|
||||
"""
|
||||
获取兑换码列表
|
||||
|
||||
Returns:
|
||||
(兑换码列表, 总数)
|
||||
"""
|
||||
codes = await self.code_repo.get_all_with_filters(
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
status=status,
|
||||
batch_id=batch_id,
|
||||
code_like=code_like,
|
||||
created_after=created_after,
|
||||
created_before=created_before,
|
||||
)
|
||||
total = await self.code_repo.count_with_filters(
|
||||
status=status,
|
||||
batch_id=batch_id,
|
||||
code_like=code_like,
|
||||
created_after=created_after,
|
||||
created_before=created_before,
|
||||
)
|
||||
return codes, total
|
||||
|
||||
async def get_code_detail(self, code_id: str) -> RedeemCode:
|
||||
"""
|
||||
获取兑换码详情
|
||||
|
||||
Args:
|
||||
code_id: 兑换码 ID
|
||||
|
||||
Returns:
|
||||
兑换码记录
|
||||
"""
|
||||
code = await self.code_repo.get_by_id(code_id)
|
||||
if not code:
|
||||
raise ResourceNotFoundError("兑换码不存在", "redeem_code", code_id)
|
||||
return code
|
||||
|
||||
async def disable_code(self, code_id: str) -> RedeemCode:
|
||||
"""
|
||||
禁用兑换码
|
||||
|
||||
Args:
|
||||
code_id: 兑换码 ID
|
||||
|
||||
Returns:
|
||||
更新后的兑换码
|
||||
"""
|
||||
code = await self.get_code_detail(code_id)
|
||||
code = await self.code_repo.disable_code(code)
|
||||
await self.code_repo.commit()
|
||||
|
||||
logger.info(f"兑换码已禁用: {code.code}")
|
||||
return code
|
||||
|
||||
async def enable_code(self, code_id: str) -> RedeemCode:
|
||||
"""
|
||||
启用兑换码
|
||||
|
||||
Args:
|
||||
code_id: 兑换码 ID
|
||||
|
||||
Returns:
|
||||
更新后的兑换码
|
||||
"""
|
||||
code = await self.get_code_detail(code_id)
|
||||
code = await self.code_repo.enable_code(code)
|
||||
await self.code_repo.commit()
|
||||
|
||||
logger.info(f"兑换码已启用: {code.code}")
|
||||
return code
|
||||
|
||||
# ============================================================
|
||||
# 管理员:批次管理
|
||||
# ============================================================
|
||||
|
||||
async def get_batches(
|
||||
self,
|
||||
*,
|
||||
offset: int = 0,
|
||||
limit: int = 20,
|
||||
) -> tuple[list[RedeemCodeBatch], int]:
|
||||
"""
|
||||
获取批次列表
|
||||
|
||||
Returns:
|
||||
(批次列表, 总数)
|
||||
"""
|
||||
batches = await self.batch_repo.get_all_batches(
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
)
|
||||
total = await self.batch_repo.count()
|
||||
return batches, total
|
||||
|
||||
async def get_batch_detail(self, batch_id: str) -> RedeemCodeBatch:
|
||||
"""
|
||||
获取批次详情
|
||||
|
||||
Args:
|
||||
batch_id: 批次 ID
|
||||
|
||||
Returns:
|
||||
批次记录
|
||||
"""
|
||||
batch = await self.batch_repo.get_by_id(batch_id)
|
||||
if not batch:
|
||||
raise ResourceNotFoundError("批次不存在", "batch", batch_id)
|
||||
return batch
|
||||
|
||||
# ============================================================
|
||||
# 管理员:使用日志
|
||||
# ============================================================
|
||||
|
||||
async def get_usage_logs(
|
||||
self,
|
||||
*,
|
||||
offset: int = 0,
|
||||
limit: int = 20,
|
||||
redeem_code_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
code_like: str | None = None,
|
||||
created_after: datetime | None = None,
|
||||
created_before: datetime | None = None,
|
||||
) -> tuple[list[RedeemCodeUsageLog], int]:
|
||||
"""
|
||||
获取使用日志
|
||||
|
||||
Returns:
|
||||
(日志列表, 总数)
|
||||
"""
|
||||
logs = await self.log_repo.get_all_with_filters(
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
redeem_code_id=redeem_code_id,
|
||||
user_id=user_id,
|
||||
code_like=code_like,
|
||||
created_after=created_after,
|
||||
created_before=created_before,
|
||||
)
|
||||
total = await self.log_repo.count_with_filters(
|
||||
redeem_code_id=redeem_code_id,
|
||||
user_id=user_id,
|
||||
code_like=code_like,
|
||||
created_after=created_after,
|
||||
created_before=created_before,
|
||||
)
|
||||
return logs, total
|
||||
|
||||
175
app/services/user.py
Normal file
175
app/services/user.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""
|
||||
用户服务
|
||||
|
||||
处理用户相关的业务逻辑。
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.exceptions import (
|
||||
UserAlreadyExistsError,
|
||||
UserNotFoundError,
|
||||
)
|
||||
from app.core.security import hash_password
|
||||
from app.models.user import User
|
||||
from app.repositories.user import UserRepository
|
||||
from app.schemas.user import UserCreate, UserUpdate
|
||||
|
||||
|
||||
class UserService:
|
||||
"""用户服务"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
"""
|
||||
初始化用户服务
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
"""
|
||||
self.session = session
|
||||
self.user_repo = UserRepository(session)
|
||||
|
||||
async def create_user(self, user_data: UserCreate) -> User:
|
||||
"""
|
||||
创建新用户
|
||||
|
||||
Args:
|
||||
user_data: 用户注册数据
|
||||
|
||||
Returns:
|
||||
新创建的用户
|
||||
|
||||
Raises:
|
||||
UserAlreadyExistsError: 用户名或邮箱已存在
|
||||
"""
|
||||
# 检查用户名是否已存在
|
||||
if await self.user_repo.exists_by_username(user_data.username):
|
||||
raise UserAlreadyExistsError("用户名")
|
||||
|
||||
# 检查邮箱是否已存在
|
||||
if user_data.email and await self.user_repo.exists_by_email(user_data.email):
|
||||
raise UserAlreadyExistsError("邮箱")
|
||||
|
||||
# 创建用户
|
||||
user = await self.user_repo.create(
|
||||
username=user_data.username.lower(),
|
||||
email=user_data.email.lower() if user_data.email else None,
|
||||
nickname=user_data.nickname,
|
||||
hashed_password=hash_password(user_data.password),
|
||||
)
|
||||
|
||||
await self.user_repo.commit()
|
||||
return user
|
||||
|
||||
async def get_user_by_id(self, user_id: str) -> User:
|
||||
"""
|
||||
通过 ID 获取用户
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
|
||||
Returns:
|
||||
用户对象
|
||||
|
||||
Raises:
|
||||
UserNotFoundError: 用户不存在
|
||||
"""
|
||||
user = await self.user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
raise UserNotFoundError(user_id)
|
||||
return user
|
||||
|
||||
async def get_user_by_username(self, username: str) -> User | None:
|
||||
"""
|
||||
通过用户名获取用户
|
||||
|
||||
Args:
|
||||
username: 用户名
|
||||
|
||||
Returns:
|
||||
用户对象或 None
|
||||
"""
|
||||
return await self.user_repo.get_by_username(username)
|
||||
|
||||
async def update_user(
|
||||
self,
|
||||
user_id: str,
|
||||
update_data: UserUpdate,
|
||||
) -> User:
|
||||
"""
|
||||
更新用户信息
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
update_data: 更新数据
|
||||
|
||||
Returns:
|
||||
更新后的用户
|
||||
|
||||
Raises:
|
||||
UserNotFoundError: 用户不存在
|
||||
UserAlreadyExistsError: 邮箱已被使用
|
||||
"""
|
||||
user = await self.get_user_by_id(user_id)
|
||||
|
||||
# 检查邮箱是否被其他用户使用
|
||||
if update_data.email:
|
||||
existing_user = await self.user_repo.get_by_email(update_data.email)
|
||||
if existing_user and existing_user.id != user_id:
|
||||
raise UserAlreadyExistsError("邮箱")
|
||||
|
||||
# 准备更新数据
|
||||
update_dict = update_data.model_dump(exclude_unset=True)
|
||||
if update_dict.get("email"):
|
||||
update_dict["email"] = update_dict["email"].lower()
|
||||
|
||||
# 更新用户
|
||||
user = await self.user_repo.update(user, **update_dict)
|
||||
await self.user_repo.commit()
|
||||
return user
|
||||
|
||||
async def update_last_login(self, user: User) -> None:
|
||||
"""
|
||||
更新用户最后登录时间
|
||||
|
||||
Args:
|
||||
user: 用户对象
|
||||
"""
|
||||
await self.user_repo.update(
|
||||
user,
|
||||
last_login_at=datetime.now(timezone.utc),
|
||||
)
|
||||
await self.user_repo.commit()
|
||||
|
||||
async def deactivate_user(self, user_id: str) -> User:
|
||||
"""
|
||||
禁用用户账户
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
|
||||
Returns:
|
||||
更新后的用户
|
||||
"""
|
||||
user = await self.get_user_by_id(user_id)
|
||||
user = await self.user_repo.update(user, is_active=False)
|
||||
await self.user_repo.commit()
|
||||
return user
|
||||
|
||||
async def activate_user(self, user_id: str) -> User:
|
||||
"""
|
||||
激活用户账户
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
|
||||
Returns:
|
||||
更新后的用户
|
||||
"""
|
||||
user = await self.get_user_by_id(user_id)
|
||||
user = await self.user_repo.update(user, is_active=True)
|
||||
await self.user_repo.commit()
|
||||
return user
|
||||
|
||||
Reference in New Issue
Block a user