提供基本前后端骨架

This commit is contained in:
hisatri
2026-01-06 23:49:23 +08:00
parent 84d4ccc226
commit 06f8176e23
89 changed files with 19293 additions and 2 deletions

6
app/__init__.py Normal file
View File

@@ -0,0 +1,6 @@
"""
SatoNano - 现代化用户认证系统
"""
__version__ = "0.1.0"

2
app/api/__init__.py Normal file
View File

@@ -0,0 +1,2 @@
"""API 模块"""

108
app/api/deps.py Normal file
View File

@@ -0,0 +1,108 @@
"""
API 依赖注入
定义 FastAPI 依赖项。
"""
from typing import Annotated
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.exceptions import (
TokenError,
UserDisabledError,
)
from app.database import get_db
from app.models.user import User
from app.services.auth import AuthService
# HTTP Bearer 认证方案
security = HTTPBearer(auto_error=False)
async def get_auth_service(
session: Annotated[AsyncSession, Depends(get_db)],
) -> AuthService:
"""获取认证服务实例"""
return AuthService(session)
async def get_current_user(
credentials: Annotated[
HTTPAuthorizationCredentials | None,
Depends(security),
],
auth_service: Annotated[AuthService, Depends(get_auth_service)],
) -> User:
"""
获取当前认证用户
从请求头的 Bearer Token 中解析用户信息。
Raises:
HTTPException 401: 未认证
HTTPException 403: 用户被禁用
"""
if not credentials:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="未提供认证信息",
headers={"WWW-Authenticate": "Bearer"},
)
try:
user = await auth_service.get_current_user(credentials.credentials)
return user
except TokenError as e:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=e.message,
headers={"WWW-Authenticate": "Bearer"},
)
except UserDisabledError as e:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=e.message,
)
async def get_current_active_user(
current_user: Annotated[User, Depends(get_current_user)],
) -> User:
"""
获取当前活跃用户
确保用户处于激活状态。
"""
if not current_user.is_active:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="账户已被禁用",
)
return current_user
async def get_current_superuser(
current_user: Annotated[User, Depends(get_current_active_user)],
) -> User:
"""
获取当前超级管理员
确保用户具有超级管理员权限。
"""
if not current_user.is_superuser:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="需要管理员权限",
)
return current_user
# 类型别名,简化依赖注入声明
DbSession = Annotated[AsyncSession, Depends(get_db)]
CurrentUser = Annotated[User, Depends(get_current_user)]
ActiveUser = Annotated[User, Depends(get_current_active_user)]
SuperUser = Annotated[User, Depends(get_current_superuser)]

2
app/api/v1/__init__.py Normal file
View File

@@ -0,0 +1,2 @@
"""API v1 模块"""

View File

@@ -0,0 +1,2 @@
"""API v1 端点"""

View File

@@ -0,0 +1,6 @@
"""管理员 API 端点"""
from app.api.v1.endpoints.admin import redeem_codes
__all__ = ["redeem_codes"]

View File

@@ -0,0 +1,613 @@
"""
管理员 - 兑换码管理 API
包括批量生成、导入导出、查询使用日志等接口。
"""
from datetime import datetime
from fastapi import APIRouter, HTTPException, Query, status
from fastapi.responses import JSONResponse
from app.api.deps import SuperUser, DbSession
from app.core.exceptions import ResourceNotFoundError, ValidationError
from app.schemas.base import APIResponse, PaginatedResponse
from app.schemas.balance import (
AdminAdjustmentRequest,
AdminBalanceResponse,
format_display,
)
from app.schemas.redeem_code import (
BatchCreateRequest,
BatchResponse,
BatchDetailResponse,
BatchListResponse,
RedeemCodeDetailResponse,
RedeemCodeListResponse,
BulkImportRequest,
BulkImportResponse,
ExportResponse,
ExportCodeItem,
UsageLogResponse,
UsageLogListResponse,
)
from app.services.balance import BalanceService, InsufficientBalanceError
from app.services.redeem_code import RedeemCodeService
from app.models.redeem_code import RedeemCodeStatus
router = APIRouter()
# ============================================================
# 批次管理
# ============================================================
@router.post(
"/batches",
response_model=APIResponse[BatchResponse],
summary="创建兑换码批次",
description="批量生成指定数量的兑换码",
)
async def create_batch(
request: BatchCreateRequest,
current_user: SuperUser,
session: DbSession,
) -> APIResponse[BatchResponse]:
"""
创建兑换码批次
批量生成指定数量的兑换码。
- **name**: 批次名称
- **face_value**: 面值(如 10.00
- **count**: 生成数量(最大 10000
- **max_uses**: 每个兑换码最大使用次数
- **expires_at**: 过期时间(可选)
"""
redeem_service = RedeemCodeService(session)
try:
batch = await redeem_service.create_batch(
name=request.name,
face_value_units=request.face_value_units,
count=request.count,
created_by=current_user.id,
description=request.description,
max_uses=request.max_uses,
expires_at=request.expires_at,
)
return APIResponse.ok(
data=BatchResponse(
id=batch.id,
name=batch.name,
description=batch.description,
face_value_units=batch.face_value,
total_count=batch.total_count,
used_count=batch.used_count,
created_by=batch.created_by,
created_at=batch.created_at,
),
message=f"成功创建批次,生成 {request.count} 个兑换码",
)
except ValidationError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=e.message,
)
@router.get(
"/batches",
response_model=APIResponse[BatchListResponse],
summary="获取批次列表",
description="获取所有兑换码批次",
)
async def get_batches(
current_user: SuperUser,
session: DbSession,
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
) -> APIResponse[BatchListResponse]:
"""获取批次列表"""
redeem_service = RedeemCodeService(session)
offset = (page - 1) * page_size
batches, total = await redeem_service.get_batches(
offset=offset,
limit=page_size,
)
items = [
BatchResponse(
id=b.id,
name=b.name,
description=b.description,
face_value_units=b.face_value,
total_count=b.total_count,
used_count=b.used_count,
created_by=b.created_by,
created_at=b.created_at,
)
for b in batches
]
return APIResponse.ok(
data=BatchListResponse.create(
items=items,
total=total,
page=page,
page_size=page_size,
),
)
@router.get(
"/batches/{batch_id}",
response_model=APIResponse[BatchDetailResponse],
summary="获取批次详情",
description="获取指定批次的详细信息",
)
async def get_batch_detail(
batch_id: str,
current_user: SuperUser,
session: DbSession,
) -> APIResponse[BatchDetailResponse]:
"""获取批次详情"""
redeem_service = RedeemCodeService(session)
try:
batch = await redeem_service.get_batch_detail(batch_id)
return APIResponse.ok(
data=BatchDetailResponse(
id=batch.id,
name=batch.name,
description=batch.description,
face_value_units=batch.face_value,
total_count=batch.total_count,
used_count=batch.used_count,
created_by=batch.created_by,
created_at=batch.created_at,
),
)
except ResourceNotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=e.message,
)
# ============================================================
# 兑换码管理
# ============================================================
@router.get(
"/codes",
response_model=APIResponse[RedeemCodeListResponse],
summary="获取兑换码列表",
description="获取所有兑换码,支持多种过滤条件",
)
async def get_codes(
current_user: SuperUser,
session: DbSession,
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
status_filter: RedeemCodeStatus | None = Query(None, alias="status"),
batch_id: str | None = Query(None),
code: str | None = Query(None, description="兑换码模糊搜索"),
created_after: datetime | None = Query(None),
created_before: datetime | None = Query(None),
) -> APIResponse[RedeemCodeListResponse]:
"""
获取兑换码列表
支持过滤:
- **status**: 状态active/used/disabled/expired
- **batch_id**: 批次 ID
- **code**: 兑换码模糊搜索
- **created_after/created_before**: 创建时间范围
"""
redeem_service = RedeemCodeService(session)
offset = (page - 1) * page_size
codes, total = await redeem_service.get_codes(
offset=offset,
limit=page_size,
status=status_filter,
batch_id=batch_id,
code_like=code,
created_after=created_after,
created_before=created_before,
)
items = [
RedeemCodeDetailResponse(
id=c.id,
code=c.code,
face_value_units=c.face_value,
status=c.status,
max_uses=c.max_uses,
used_count=c.used_count,
expires_at=c.expires_at,
used_at=c.used_at,
created_at=c.created_at,
batch_id=c.batch_id,
batch_name=c.batch.name if c.batch else None,
remark=c.remark,
created_by=c.created_by,
used_by=c.used_by,
)
for c in codes
]
return APIResponse.ok(
data=RedeemCodeListResponse.create(
items=items,
total=total,
page=page,
page_size=page_size,
),
)
@router.get(
"/codes/{code_id}",
response_model=APIResponse[RedeemCodeDetailResponse],
summary="获取兑换码详情",
)
async def get_code_detail(
code_id: str,
current_user: SuperUser,
session: DbSession,
) -> APIResponse[RedeemCodeDetailResponse]:
"""获取兑换码详情"""
redeem_service = RedeemCodeService(session)
try:
code = await redeem_service.get_code_detail(code_id)
return APIResponse.ok(
data=RedeemCodeDetailResponse(
id=code.id,
code=code.code,
face_value_units=code.face_value,
status=code.status,
max_uses=code.max_uses,
used_count=code.used_count,
expires_at=code.expires_at,
used_at=code.used_at,
created_at=code.created_at,
batch_id=code.batch_id,
batch_name=code.batch.name if code.batch else None,
remark=code.remark,
created_by=code.created_by,
used_by=code.used_by,
),
)
except ResourceNotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=e.message,
)
@router.post(
"/codes/{code_id}/disable",
response_model=APIResponse[RedeemCodeDetailResponse],
summary="禁用兑换码",
)
async def disable_code(
code_id: str,
current_user: SuperUser,
session: DbSession,
) -> APIResponse[RedeemCodeDetailResponse]:
"""禁用指定兑换码"""
redeem_service = RedeemCodeService(session)
try:
code = await redeem_service.disable_code(code_id)
return APIResponse.ok(
data=RedeemCodeDetailResponse(
id=code.id,
code=code.code,
face_value_units=code.face_value,
status=code.status,
max_uses=code.max_uses,
used_count=code.used_count,
expires_at=code.expires_at,
used_at=code.used_at,
created_at=code.created_at,
batch_id=code.batch_id,
remark=code.remark,
created_by=code.created_by,
used_by=code.used_by,
),
message="兑换码已禁用",
)
except ResourceNotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=e.message,
)
@router.post(
"/codes/{code_id}/enable",
response_model=APIResponse[RedeemCodeDetailResponse],
summary="启用兑换码",
)
async def enable_code(
code_id: str,
current_user: SuperUser,
session: DbSession,
) -> APIResponse[RedeemCodeDetailResponse]:
"""重新启用指定兑换码"""
redeem_service = RedeemCodeService(session)
try:
code = await redeem_service.enable_code(code_id)
return APIResponse.ok(
data=RedeemCodeDetailResponse(
id=code.id,
code=code.code,
face_value_units=code.face_value,
status=code.status,
max_uses=code.max_uses,
used_count=code.used_count,
expires_at=code.expires_at,
used_at=code.used_at,
created_at=code.created_at,
batch_id=code.batch_id,
remark=code.remark,
created_by=code.created_by,
used_by=code.used_by,
),
message="兑换码已启用",
)
except ResourceNotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=e.message,
)
# ============================================================
# 导入导出
# ============================================================
@router.post(
"/codes/import",
response_model=APIResponse[BulkImportResponse],
summary="批量导入兑换码",
description="批量导入自定义兑换码",
)
async def import_codes(
request: BulkImportRequest,
current_user: SuperUser,
session: DbSession,
) -> APIResponse[BulkImportResponse]:
"""
批量导入兑换码
可以导入自定义格式的兑换码。
- **codes**: 兑换码列表
- **batch_name**: 批次名称(可选)
"""
redeem_service = RedeemCodeService(session)
# 转换请求数据
codes_data = [
{
"code": c.code,
"face_value_units": c.face_value_units,
"max_uses": c.max_uses,
"expires_at": c.expires_at,
"remark": c.remark,
}
for c in request.codes
]
result = await redeem_service.import_codes(
codes_data,
created_by=current_user.id,
batch_name=request.batch_name,
)
return APIResponse.ok(
data=BulkImportResponse(
success_count=result["success_count"],
failed_count=result["failed_count"],
failed_codes=result["failed_codes"],
batch_id=result["batch_id"],
),
message=f"导入完成:成功 {result['success_count']},失败 {result['failed_count']}",
)
@router.get(
"/codes/export",
response_model=APIResponse[ExportResponse],
summary="导出兑换码",
description="导出兑换码数据",
)
async def export_codes(
current_user: SuperUser,
session: DbSession,
batch_id: str | None = Query(None),
status_filter: RedeemCodeStatus | None = Query(None, alias="status"),
limit: int = Query(10000, ge=1, le=50000),
) -> APIResponse[ExportResponse]:
"""
导出兑换码
支持按批次或状态过滤。
"""
redeem_service = RedeemCodeService(session)
codes = await redeem_service.export_codes(
batch_id=batch_id,
status=status_filter,
limit=limit,
)
items = [
ExportCodeItem(
code=c["code"],
face_value=c["face_value"],
status=c["status"],
max_uses=c["max_uses"],
used_count=c["used_count"],
expires_at=c["expires_at"],
created_at=c["created_at"],
used_at=c["used_at"],
used_by=c["used_by"],
)
for c in codes
]
return APIResponse.ok(
data=ExportResponse(
total=len(items),
codes=items,
),
)
# ============================================================
# 使用日志
# ============================================================
@router.get(
"/usage-logs",
response_model=APIResponse[UsageLogListResponse],
summary="获取使用日志",
description="查询兑换码使用记录",
)
async def get_usage_logs(
current_user: SuperUser,
session: DbSession,
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
redeem_code_id: str | None = Query(None),
user_id: str | None = Query(None),
code: str | None = Query(None, description="兑换码模糊搜索"),
created_after: datetime | None = Query(None),
created_before: datetime | None = Query(None),
) -> APIResponse[UsageLogListResponse]:
"""
获取使用日志
支持过滤:
- **redeem_code_id**: 兑换码 ID
- **user_id**: 用户 ID
- **code**: 兑换码模糊搜索
- **created_after/created_before**: 使用时间范围
"""
redeem_service = RedeemCodeService(session)
offset = (page - 1) * page_size
logs, total = await redeem_service.get_usage_logs(
offset=offset,
limit=page_size,
redeem_code_id=redeem_code_id,
user_id=user_id,
code_like=code,
created_after=created_after,
created_before=created_before,
)
items = [
UsageLogResponse(
id=log.id,
redeem_code_id=log.redeem_code_id,
code_snapshot=log.code_snapshot,
user_id=log.user_id,
username=log.user.username if log.user else None,
face_value=format_display(log.face_value),
ip_address=log.ip_address,
created_at=log.created_at,
)
for log in logs
]
return APIResponse.ok(
data=UsageLogListResponse.create(
items=items,
total=total,
page=page,
page_size=page_size,
),
)
# ============================================================
# 余额管理
# ============================================================
@router.post(
"/balance/adjust",
response_model=APIResponse[AdminBalanceResponse],
summary="调整用户余额",
description="管理员手动调整用户余额",
)
async def adjust_balance(
request: AdminAdjustmentRequest,
current_user: SuperUser,
session: DbSession,
) -> APIResponse[AdminBalanceResponse]:
"""
调整用户余额
- **user_id**: 目标用户 ID
- **amount**: 调整金额(正数增加,负数减少)
- **reason**: 调整原因
"""
balance_service = BalanceService(session)
try:
transaction = await balance_service.admin_adjust(
request.user_id,
request.amount_units,
operator_id=current_user.id,
reason=request.reason,
)
# 获取更新后的余额
balance = await balance_service.get_balance(request.user_id)
return APIResponse.ok(
data=AdminBalanceResponse(
user_id=balance.user_id,
username=balance.user.username if balance.user else "",
balance=format_display(balance.balance),
available_balance=format_display(balance.available_balance),
frozen_balance=format_display(balance.frozen_balance),
total_recharged=format_display(balance.total_recharged),
total_consumed=format_display(balance.total_consumed),
version=balance.version,
created_at=balance.created_at,
updated_at=balance.updated_at,
),
message=f"余额调整成功,变动 {request.amount:+.2f}",
)
except InsufficientBalanceError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=e.message,
)
except ValidationError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=e.message,
)

View File

@@ -0,0 +1,191 @@
"""
认证相关 API
包括注册、登录、退出、刷新令牌、修改密码等接口。
"""
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import ActiveUser, DbSession, get_auth_service
from app.core.exceptions import (
InvalidCredentialsError,
PasswordValidationError,
ResourceConflictError,
TokenError,
UserDisabledError,
)
from app.schemas.auth import (
LoginRequest,
PasswordChangeRequest,
RefreshTokenRequest,
TokenResponse,
)
from app.schemas.base import APIResponse
from app.schemas.user import UserCreate, UserResponse
from app.services.auth import AuthService
from app.services.user import UserService
router = APIRouter()
@router.post(
"/register",
response_model=APIResponse[UserResponse],
status_code=status.HTTP_201_CREATED,
summary="用户注册",
description="创建新用户账户",
)
async def register(
user_data: UserCreate,
session: DbSession,
) -> APIResponse[UserResponse]:
"""
用户注册接口
- **username**: 用户名字母开头3-32位
- **email**: 邮箱(可选)
- **password**: 密码8位以上需包含大小写字母和数字
- **nickname**: 昵称(可选)
"""
user_service = UserService(session)
try:
user = await user_service.create_user(user_data)
return APIResponse.ok(
data=UserResponse.model_validate(user),
message="注册成功",
)
except ResourceConflictError as e:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=e.message,
)
@router.post(
"/login",
response_model=APIResponse[TokenResponse],
summary="用户登录",
description="使用用户名/邮箱和密码登录",
)
async def login(
login_data: LoginRequest,
auth_service: Annotated[AuthService, Depends(get_auth_service)],
) -> APIResponse[TokenResponse]:
"""
用户登录接口
- **username**: 用户名或邮箱
- **password**: 密码
返回访问令牌和刷新令牌。
"""
try:
_, tokens = await auth_service.login(
username=login_data.username,
password=login_data.password,
)
return APIResponse.ok(data=tokens, message="登录成功")
except InvalidCredentialsError as e:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=e.message,
headers={"WWW-Authenticate": "Bearer"},
)
except UserDisabledError as e:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=e.message,
)
@router.post(
"/logout",
response_model=APIResponse[None],
summary="用户退出",
description="退出登录(客户端应删除本地令牌)",
)
async def logout(
current_user: ActiveUser,
) -> APIResponse[None]:
"""
用户退出接口
由于使用的是无状态 JWT服务端不存储令牌
因此退出登录主要由客户端删除本地存储的令牌实现。
如果需要实现令牌黑名单,可以在后续版本中添加。
"""
# 可以在这里添加令牌黑名单逻辑
# 或者记录退出日志
return APIResponse.ok(message="退出成功")
@router.post(
"/refresh",
response_model=APIResponse[TokenResponse],
summary="刷新令牌",
description="使用刷新令牌获取新的访问令牌",
)
async def refresh_token(
token_data: RefreshTokenRequest,
auth_service: Annotated[AuthService, Depends(get_auth_service)],
) -> APIResponse[TokenResponse]:
"""
刷新令牌接口
使用刷新令牌获取新的访问令牌和刷新令牌。
"""
try:
tokens = await auth_service.refresh_tokens(token_data.refresh_token)
return APIResponse.ok(data=tokens, message="刷新成功")
except TokenError as e:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=e.message,
headers={"WWW-Authenticate": "Bearer"},
)
except UserDisabledError as e:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=e.message,
)
@router.post(
"/change-password",
response_model=APIResponse[None],
summary="修改密码",
description="修改当前用户密码",
)
async def change_password(
password_data: PasswordChangeRequest,
current_user: ActiveUser,
auth_service: Annotated[AuthService, Depends(get_auth_service)],
) -> APIResponse[None]:
"""
修改密码接口
- **current_password**: 当前密码
- **new_password**: 新密码
"""
try:
await auth_service.change_password(
user_id=current_user.id,
password_data=password_data,
)
return APIResponse.ok(message="密码修改成功")
except InvalidCredentialsError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=e.message,
)
except PasswordValidationError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=e.message,
)

View File

@@ -0,0 +1,270 @@
"""
余额相关 API
包括余额查询、交易记录、兑换等接口。
"""
from fastapi import APIRouter, HTTPException, Query, Request, status
from app.api.deps import ActiveUser, DbSession
from app.core.exceptions import AppException, ValidationError
from app.schemas.base import APIResponse, PaginatedResponse
from app.schemas.balance import (
BalanceResponse,
TransactionResponse,
DeductionRequest,
DeductionResponse,
UNITS_PER_DISPLAY,
format_display,
)
from app.schemas.redeem_code import RedeemRequest, RedeemResponse
from app.services.balance import (
BalanceService,
InsufficientBalanceError,
DuplicateTransactionError,
)
from app.services.redeem_code import (
RedeemCodeService,
RedeemCodeNotFoundError,
RedeemCodeInvalidError,
RedeemCodeExpiredError,
RedeemCodeUsedError,
RedeemCodeDisabledError,
)
from app.models.balance import TransactionType
router = APIRouter()
# ============================================================
# 余额查询
# ============================================================
@router.get(
"",
response_model=APIResponse[BalanceResponse],
summary="获取当前用户余额",
description="获取当前登录用户的余额信息",
)
async def get_my_balance(
current_user: ActiveUser,
session: DbSession,
) -> APIResponse[BalanceResponse]:
"""
获取当前用户余额
返回:
- **balance**: 当前总余额
- **available_balance**: 可用余额(总余额 - 冻结)
- **frozen_balance**: 冻结余额
- **total_recharged**: 累计充值
- **total_consumed**: 累计消费
"""
balance_service = BalanceService(session)
balance = await balance_service.get_balance(current_user.id)
return APIResponse.ok(
data=BalanceResponse(
user_id=balance.user_id,
balance_units=balance.balance,
frozen_units=balance.frozen_balance,
total_recharged_units=balance.total_recharged,
total_consumed_units=balance.total_consumed,
),
)
# ============================================================
# 交易记录
# ============================================================
@router.get(
"/transactions",
response_model=APIResponse[PaginatedResponse[TransactionResponse]],
summary="获取交易记录",
description="获取当前用户的余额交易记录",
)
async def get_my_transactions(
current_user: ActiveUser,
session: DbSession,
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
transaction_type: TransactionType | None = Query(
None, description="交易类型过滤"
),
) -> APIResponse[PaginatedResponse[TransactionResponse]]:
"""
获取交易记录
支持按交易类型过滤:
- **recharge**: 充值
- **deduction**: 扣款
- **refund**: 退款
- **adjustment**: 管理员调整
"""
balance_service = BalanceService(session)
offset = (page - 1) * page_size
transactions, total = await balance_service.get_transactions(
current_user.id,
offset=offset,
limit=page_size,
transaction_type=transaction_type,
)
items = [
TransactionResponse(
id=t.id,
transaction_type=t.transaction_type,
status=t.status,
amount_units=t.amount,
balance_before_units=t.balance_before,
balance_after_units=t.balance_after,
reference_type=t.reference_type,
reference_id=t.reference_id,
description=t.description,
created_at=t.created_at,
)
for t in transactions
]
return APIResponse.ok(
data=PaginatedResponse.create(
items=items,
total=total,
page=page,
page_size=page_size,
),
)
# ============================================================
# 兑换码兑换
# ============================================================
@router.post(
"/redeem",
response_model=APIResponse[RedeemResponse],
summary="兑换余额",
description="使用兑换码充值余额",
)
async def redeem_code(
request: Request,
redeem_request: RedeemRequest,
current_user: ActiveUser,
session: DbSession,
) -> APIResponse[RedeemResponse]:
"""
使用兑换码充值余额
- **code**: 兑换码(格式如 XXXX-XXXX-XXXX-XXXX
返回兑换结果,包括充值金额和余额变化。
"""
redeem_service = RedeemCodeService(session)
# 获取客户端信息
ip_address = request.client.host if request.client else None
user_agent = request.headers.get("user-agent")
try:
result = await redeem_service.redeem(
current_user.id,
redeem_request.code,
ip_address=ip_address,
user_agent=user_agent,
)
return APIResponse.ok(
data=RedeemResponse(
success=True,
message="兑换成功",
face_value=result["face_value"],
balance_before=result["balance_before"],
balance_after=result["balance_after"],
),
message="兑换成功",
)
except RedeemCodeNotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=e.message,
)
except (
RedeemCodeInvalidError,
RedeemCodeExpiredError,
RedeemCodeUsedError,
RedeemCodeDisabledError,
) as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=e.message,
)
# ============================================================
# 扣款(内部 API / 服务间调用)
# ============================================================
@router.post(
"/deduct",
response_model=APIResponse[DeductionResponse],
summary="扣款",
description="从当前用户余额中扣款(通常由其他服务调用)",
include_in_schema=True, # 可设为 False 隐藏此 API
)
async def deduct_balance(
deduction_request: DeductionRequest,
current_user: ActiveUser,
session: DbSession,
) -> APIResponse[DeductionResponse]:
"""
扣款
从当前用户余额中扣除指定金额。
- **amount**: 扣款金额(如 1.00
- **reference_type**: 关联业务类型(如 api_call
- **reference_id**: 关联业务 ID
- **description**: 交易描述
- **idempotency_key**: 幂等键(防止重复扣款)
"""
balance_service = BalanceService(session)
try:
transaction = await balance_service.deduct(
current_user.id,
deduction_request.amount_units,
reference_type=deduction_request.reference_type,
reference_id=deduction_request.reference_id,
description=deduction_request.description,
idempotency_key=deduction_request.idempotency_key,
)
return APIResponse.ok(
data=DeductionResponse(
transaction_id=transaction.id,
amount=format_display(abs(transaction.amount)),
balance_before=format_display(transaction.balance_before),
balance_after=format_display(transaction.balance_after),
),
message="扣款成功",
)
except InsufficientBalanceError as e:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=e.message,
)
except DuplicateTransactionError as e:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=e.message,
)
except ValidationError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=e.message,
)

View File

@@ -0,0 +1,157 @@
"""
OAuth2 认证 API
提供 OAuth2 第三方登录功能。
"""
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import DbSession
from app.core.config import settings
from app.core.exceptions import AuthenticationError
from app.schemas.base import APIResponse
from app.schemas.oauth2 import OAuth2AuthorizeResponse, OAuth2LoginResponse
from app.services.oauth2 import OAuth2Service, OAuth2StateError
router = APIRouter()
def get_oauth2_service(session: DbSession) -> OAuth2Service:
"""获取 OAuth2 服务实例"""
return OAuth2Service(session)
def _get_redirect_uri(request: Request) -> str:
"""
获取 OAuth2 回调 URL
根据请求动态构建完整的回调 URL
"""
# 优先使用 X-Forwarded 头(反向代理场景)
scheme = request.headers.get("X-Forwarded-Proto", request.url.scheme)
host = request.headers.get("X-Forwarded-Host", request.url.netloc)
return f"{scheme}://{host}{settings.oauth2_callback_path}"
@router.get(
"/authorize",
response_model=APIResponse[OAuth2AuthorizeResponse],
summary="获取 OAuth2 授权 URL",
description="获取第三方登录授权页面 URL",
)
async def get_authorize_url(
request: Request,
session: DbSession,
) -> APIResponse[OAuth2AuthorizeResponse]:
"""
获取 OAuth2 授权 URL
返回授权页面 URL 和状态码,客户端应重定向用户到该 URL。
流程:
1. 前端调用此接口获取授权 URL
2. 前端将用户重定向到授权 URL
3. 用户在第三方平台完成授权
4. 第三方平台重定向回回调 URL携带 code 和 state
5. 前端调用 /callback 接口完成登录
"""
# 检查 OAuth2 是否已配置
if not settings.oauth2_client_id or not settings.oauth2_client_secret:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="OAuth2 登录未配置",
)
oauth2_service = OAuth2Service(session)
redirect_uri = _get_redirect_uri(request)
authorize_url, state = oauth2_service.generate_authorize_url(redirect_uri)
return APIResponse.ok(
data=OAuth2AuthorizeResponse(
authorize_url=authorize_url,
state=state,
),
message="请重定向到授权 URL",
)
@router.get(
"/callback",
response_model=APIResponse[OAuth2LoginResponse],
summary="OAuth2 回调",
description="处理 OAuth2 授权回调,完成登录",
)
async def oauth2_callback(
request: Request,
code: Annotated[str, Query(description="授权码")],
state: Annotated[str, Query(description="状态码")],
session: DbSession,
) -> APIResponse[OAuth2LoginResponse]:
"""
OAuth2 回调接口
处理第三方平台的授权回调:
1. 验证 state 防止 CSRF 攻击
2. 用 code 换取访问令牌
3. 获取用户信息
4. 创建或关联本地用户
5. 返回 JWT 令牌
"""
oauth2_service = OAuth2Service(session)
redirect_uri = _get_redirect_uri(request)
try:
user, tokens, is_new_user = await oauth2_service.authenticate(
code=code,
state=state,
redirect_uri=redirect_uri,
)
return APIResponse.ok(
data=OAuth2LoginResponse(
access_token=tokens.access_token,
refresh_token=tokens.refresh_token,
token_type=tokens.token_type,
expires_in=tokens.expires_in,
is_new_user=is_new_user,
),
message="登录成功" if not is_new_user else "注册并登录成功",
)
except OAuth2StateError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=e.message,
)
except AuthenticationError as e:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=e.message,
)
@router.post(
"/callback",
response_model=APIResponse[OAuth2LoginResponse],
summary="OAuth2 回调 (POST)",
description="处理 OAuth2 授权回调POST 方式)",
)
async def oauth2_callback_post(
request: Request,
code: Annotated[str, Query(description="授权码")],
state: Annotated[str, Query(description="状态码")],
session: DbSession,
) -> APIResponse[OAuth2LoginResponse]:
"""
OAuth2 回调接口POST 方式)
某些场景下前端可能使用 POST 方式调用回调。
逻辑与 GET 方式相同。
"""
return await oauth2_callback(request, code, state, session)

View File

@@ -0,0 +1,103 @@
"""
用户相关 API
包括获取用户信息、更新用户资料等接口。
"""
from fastapi import APIRouter, HTTPException, status
from app.api.deps import ActiveUser, DbSession
from app.core.exceptions import ResourceConflictError, UserNotFoundError
from app.schemas.base import APIResponse
from app.schemas.user import UserResponse, UserUpdate
from app.services.user import UserService
router = APIRouter()
@router.get(
"/me",
response_model=APIResponse[UserResponse],
summary="获取当前用户信息",
description="获取当前登录用户的详细信息",
)
async def get_current_user_info(
current_user: ActiveUser,
) -> APIResponse[UserResponse]:
"""
获取当前用户信息
返回当前登录用户的完整信息。
"""
return APIResponse.ok(
data=UserResponse.model_validate(current_user),
)
@router.patch(
"/me",
response_model=APIResponse[UserResponse],
summary="更新当前用户信息",
description="更新当前登录用户的资料",
)
async def update_current_user(
update_data: UserUpdate,
current_user: ActiveUser,
session: DbSession,
) -> APIResponse[UserResponse]:
"""
更新当前用户信息
支持更新:
- **nickname**: 昵称
- **email**: 邮箱
- **avatar_url**: 头像 URL
- **bio**: 个人简介
"""
user_service = UserService(session)
try:
user = await user_service.update_user(
user_id=current_user.id,
update_data=update_data,
)
return APIResponse.ok(
data=UserResponse.model_validate(user),
message="更新成功",
)
except ResourceConflictError as e:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=e.message,
)
@router.get(
"/{user_id}",
response_model=APIResponse[UserResponse],
summary="获取指定用户信息",
description="获取指定用户的公开信息",
)
async def get_user_by_id(
user_id: str,
current_user: ActiveUser,
session: DbSession,
) -> APIResponse[UserResponse]:
"""
获取指定用户信息
- **user_id**: 用户 ID
"""
user_service = UserService(session)
try:
user = await user_service.get_user_by_id(user_id)
return APIResponse.ok(
data=UserResponse.model_validate(user),
)
except UserNotFoundError:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="用户不存在",
)

45
app/api/v1/router.py Normal file
View File

@@ -0,0 +1,45 @@
"""
API v1 路由聚合
汇总所有 v1 版本的 API 路由。
"""
from fastapi import APIRouter
from app.api.v1.endpoints import auth, oauth2, users, balance
from app.api.v1.endpoints.admin import redeem_codes as admin_redeem_codes
api_router = APIRouter()
# 注册子路由
api_router.include_router(
auth.router,
prefix="/auth",
tags=["认证"],
)
api_router.include_router(
oauth2.router,
prefix="/auth/oauth2",
tags=["OAuth2 登录"],
)
api_router.include_router(
users.router,
prefix="/users",
tags=["用户"],
)
api_router.include_router(
balance.router,
prefix="/balance",
tags=["余额"],
)
# 管理员路由
api_router.include_router(
admin_redeem_codes.router,
prefix="/admin/redeem-codes",
tags=["管理员 - 兑换码"],
)

29
app/core/__init__.py Normal file
View File

@@ -0,0 +1,29 @@
"""核心功能模块"""
from .config import Settings, get_settings, settings
from .config_loader import (
ConfigLoader,
ConfigSource,
ConfigSourceError,
EnvConfigSource,
YamlConfigSource,
config_loader,
create_config_loader,
get_config_loader,
)
__all__ = [
# 配置系统
"Settings",
"get_settings",
"settings",
# 配置加载器
"ConfigLoader",
"ConfigSource",
"ConfigSourceError",
"EnvConfigSource",
"YamlConfigSource",
"config_loader",
"create_config_loader",
"get_config_loader",
]

202
app/core/config.py Normal file
View File

@@ -0,0 +1,202 @@
"""
应用配置管理
使用 Pydantic Settings 进行类型安全的配置管理,
集成 ConfigLoader 支持环境变量和 YAML 配置文件。
配置优先级(从高到低):
1. 环境变量(前缀 SATONANO_
2. config.yaml 文件
3. .env 文件
4. 默认值
"""
from functools import lru_cache
from pathlib import Path
from typing import Any, Literal
from pydantic import Field, computed_field
from pydantic_settings import BaseSettings, SettingsConfigDict
from .config_loader import create_config_loader
def _yaml_settings_source() -> dict[str, Any]:
"""
YAML 配置源工厂函数
为 Pydantic Settings 提供 YAML 配置数据
"""
# 查找配置文件路径
config_paths = [
Path("config.yaml"),
Path("config.yml"),
Path(__file__).parent.parent.parent / "config.yaml",
Path(__file__).parent.parent.parent / "config.yml",
]
yaml_path = None
for path in config_paths:
if path.exists():
yaml_path = path
break
if yaml_path is None:
return {}
# 使用 ConfigLoader 加载配置(不带环境变量前缀,让 pydantic 处理)
loader = create_config_loader(
yaml_path=yaml_path,
env_prefix="", # 不读取环境变量,由 pydantic-settings 处理
yaml_required=False,
)
return loader.load()
class Settings(BaseSettings):
"""应用配置"""
model_config = SettingsConfigDict(
env_prefix="SATONANO_",
env_file=".env",
env_file_encoding="utf-8",
case_sensitive=False,
extra="ignore",
)
# 应用基础配置
app_name: str = "SatoNano"
app_version: str = "0.1.0"
debug: bool = False
environment: Literal["development", "staging", "production"] = "development"
# API 配置
api_v1_prefix: str = "/api/v1"
# 安全配置
secret_key: str = Field(
default="CHANGE-THIS-SECRET-KEY-IN-PRODUCTION",
description="JWT 签名密钥,生产环境必须更改",
)
algorithm: str = "HS256"
access_token_expire_minutes: int = 30
refresh_token_expire_days: int = 7
# 数据库配置
database_type: Literal["sqlite", "mysql", "postgresql"] = "sqlite"
database_host: str = "localhost"
database_port: int = 3306
database_name: str = "satonano"
database_username: str = ""
database_password: str = ""
database_sqlite_path: str = "./satonano.db"
database_echo: bool = False
# OAuth2 配置 (Linux.do)
oauth2_client_id: str = ""
oauth2_client_secret: str = ""
oauth2_callback_path: str = "/oauth2/callback"
# 首选端点
oauth2_authorize_endpoint: str = "https://connect.linux.do/oauth2/authorize"
oauth2_token_endpoint: str = "https://connect.linux.do/oauth2/token"
oauth2_user_info_endpoint: str = "https://connect.linux.do/api/user"
# 备用端点
oauth2_authorize_endpoint_reserve: str = "https://connect.linuxdo.org/oauth2/authorize"
oauth2_token_endpoint_reserve: str = "https://connect.linuxdo.org/oauth2/token"
oauth2_user_info_endpoint_reserve: str = "https://connect.linuxdo.org/api/user"
oauth2_request_timeout: int = 10 # 请求超时时间(秒)
# 密码策略
password_min_length: int = 8
password_max_length: int = 128
password_require_uppercase: bool = True
password_require_lowercase: bool = True
password_require_digit: bool = True
password_require_special: bool = False
# 用户名策略
username_min_length: int = 3
username_max_length: int = 32
# 前端静态文件配置
frontend_static_path: str = "./frontend/out"
@computed_field
@property
def database_url(self) -> str:
"""
动态构建数据库连接 URL
根据 database_type 自动生成对应的连接字符串
"""
if self.database_type == "sqlite":
return f"sqlite+aiosqlite:///{self.database_sqlite_path}"
if self.database_type == "mysql":
return (
f"mysql+aiomysql://{self.database_username}:{self.database_password}"
f"@{self.database_host}:{self.database_port}/{self.database_name}"
)
if self.database_type == "postgresql":
return (
f"postgresql+asyncpg://{self.database_username}:{self.database_password}"
f"@{self.database_host}:{self.database_port}/{self.database_name}"
)
return f"sqlite+aiosqlite:///{self.database_sqlite_path}"
@computed_field
@property
def is_production(self) -> bool:
"""是否为生产环境"""
return self.environment == "production"
@classmethod
def settings_customise_sources(
cls,
settings_cls: type[BaseSettings],
init_settings: Any,
env_settings: Any,
dotenv_settings: Any,
file_secret_settings: Any,
) -> tuple[Any, ...]:
"""
自定义配置源优先级
优先级顺序(从高到低):
1. 初始化参数
2. 环境变量
3. YAML 文件
4. .env 文件
5. 文件密钥
"""
class YamlConfigSettingsSource:
"""YAML 配置源适配器"""
def __init__(self, settings_cls: type[BaseSettings]) -> None:
self.settings_cls = settings_cls
self._yaml_data: dict[str, Any] | None = None
def __call__(self) -> dict[str, Any]:
if self._yaml_data is None:
self._yaml_data = _yaml_settings_source()
return self._yaml_data
return (
init_settings,
env_settings,
YamlConfigSettingsSource(settings_cls),
dotenv_settings,
file_secret_settings,
)
@lru_cache
def get_settings() -> Settings:
"""获取配置单例"""
return Settings()
settings = get_settings()

387
app/core/config_loader.py Normal file
View File

@@ -0,0 +1,387 @@
"""
配置加载器模块
提供统一的配置加载机制,支持多数据源配置合并:
- 环境变量(优先级最高)
- YAML 配置文件
- 默认值(优先级最低)
设计原则:
- 单一职责:仅负责配置的加载和合并
- 依赖倒置:通过抽象接口解耦配置源
- 开闭原则:易于扩展新的配置源
"""
from __future__ import annotations
import os
from abc import ABC, abstractmethod
from functools import lru_cache
from pathlib import Path
from typing import Any, TypeVar
import yaml
T = TypeVar("T")
class ConfigSourceError(Exception):
"""配置源加载错误"""
def __init__(self, source: str, message: str) -> None:
self.source = source
self.message = message
super().__init__(f"[{source}] {message}")
class ConfigSource(ABC):
"""配置源抽象基类"""
@property
@abstractmethod
def name(self) -> str:
"""配置源名称"""
...
@property
@abstractmethod
def priority(self) -> int:
"""优先级,数值越大优先级越高"""
...
@abstractmethod
def load(self) -> dict[str, Any]:
"""加载配置,返回扁平化的键值对"""
...
class EnvConfigSource(ConfigSource):
"""环境变量配置源"""
def __init__(self, prefix: str = "") -> None:
"""
初始化环境变量配置源
Args:
prefix: 环境变量前缀,用于过滤相关配置
例如 prefix="SATONANO_" 时只读取以此为前缀的变量
"""
self._prefix = prefix.upper()
@property
def name(self) -> str:
return "environment"
@property
def priority(self) -> int:
return 100 # 最高优先级
def load(self) -> dict[str, Any]:
"""
加载环境变量
Returns:
配置字典,键名转换为小写并移除前缀
"""
config: dict[str, Any] = {}
prefix_len = len(self._prefix)
for key, value in os.environ.items():
if self._prefix and not key.upper().startswith(self._prefix):
continue
# 移除前缀并转为小写
config_key = key[prefix_len:].lower() if self._prefix else key.lower()
config[config_key] = self._parse_value(value)
return config
@staticmethod
def _parse_value(value: str) -> Any:
"""
解析环境变量值,自动转换类型
Args:
value: 原始字符串值
Returns:
转换后的值bool/int/float/str
"""
# 布尔值
if value.lower() in ("true", "yes", "1", "on"):
return True
if value.lower() in ("false", "no", "0", "off"):
return False
# 整数
try:
return int(value)
except ValueError:
pass
# 浮点数
try:
return float(value)
except ValueError:
pass
return value
class YamlConfigSource(ConfigSource):
"""YAML 配置文件源"""
def __init__(
self,
file_path: str | Path = "config.yaml",
required: bool = False,
) -> None:
"""
初始化 YAML 配置源
Args:
file_path: 配置文件路径
required: 是否必须存在,为 True 时文件不存在会抛出异常
"""
self._file_path = Path(file_path)
self._required = required
@property
def name(self) -> str:
return f"yaml:{self._file_path}"
@property
def priority(self) -> int:
return 50 # 中等优先级
def load(self) -> dict[str, Any]:
"""
加载 YAML 配置文件
Returns:
配置字典
Raises:
ConfigSourceError: 文件不存在required=True或解析错误
"""
if not self._file_path.exists():
if self._required:
raise ConfigSourceError(
self.name,
f"配置文件不存在: {self._file_path.absolute()}",
)
return {}
try:
with open(self._file_path, encoding="utf-8") as f:
data = yaml.safe_load(f)
except yaml.YAMLError as e:
raise ConfigSourceError(self.name, f"YAML 解析错误: {e}") from e
if data is None:
return {}
# 处理列表格式(兼容用户示例格式)
if isinstance(data, list):
return self._flatten_list(data)
# 标准字典格式
if isinstance(data, dict):
return self._normalize_keys(data)
raise ConfigSourceError(
self.name,
f"不支持的配置格式,期望 dict 或 list得到 {type(data).__name__}",
)
@staticmethod
def _flatten_list(items: list) -> dict[str, Any]:
"""
将列表格式配置展平为字典
支持格式:
- key: value
- key: value
"""
result: dict[str, Any] = {}
for item in items:
if isinstance(item, dict):
for key, value in item.items():
result[key.lower()] = value
return result
@staticmethod
def _normalize_keys(data: dict) -> dict[str, Any]:
"""键名标准化为小写"""
return {k.lower(): v for k, v in data.items()}
class ConfigLoader:
"""
配置加载器
负责从多个配置源加载并合并配置,按优先级覆盖。
Example:
>>> loader = ConfigLoader()
>>> loader.add_source(YamlConfigSource("config.yaml"))
>>> loader.add_source(EnvConfigSource("SATONANO_"))
>>> config = loader.load()
>>> db_type = loader.get("database_type", "sqlite")
"""
def __init__(self) -> None:
self._sources: list[ConfigSource] = []
self._config: dict[str, Any] = {}
self._loaded = False
def add_source(self, source: ConfigSource) -> ConfigLoader:
"""
添加配置源
Args:
source: 配置源实例
Returns:
self支持链式调用
"""
self._sources.append(source)
self._loaded = False # 标记需要重新加载
return self
def load(self) -> dict[str, Any]:
"""
加载并合并所有配置源
Returns:
合并后的配置字典
"""
# 按优先级升序排序,后加载的覆盖先加载的
sorted_sources = sorted(self._sources, key=lambda s: s.priority)
self._config = {}
for source in sorted_sources:
source_config = source.load()
self._config.update(source_config)
self._loaded = True
return self._config.copy()
def get(self, key: str, default: T = None) -> T | Any:
"""
获取配置值
Args:
key: 配置键名(不区分大小写)
default: 默认值
Returns:
配置值或默认值
"""
if not self._loaded:
self.load()
return self._config.get(key.lower(), default)
def get_str(self, key: str, default: str = "") -> str:
"""获取字符串配置"""
value = self.get(key, default)
return str(value) if value is not None else default
def get_int(self, key: str, default: int = 0) -> int:
"""获取整数配置"""
value = self.get(key, default)
if isinstance(value, int):
return value
try:
return int(value)
except (ValueError, TypeError):
return default
def get_float(self, key: str, default: float = 0.0) -> float:
"""获取浮点数配置"""
value = self.get(key, default)
if isinstance(value, float):
return value
try:
return float(value)
except (ValueError, TypeError):
return default
def get_bool(self, key: str, default: bool = False) -> bool:
"""获取布尔配置"""
value = self.get(key, default)
if isinstance(value, bool):
return value
if isinstance(value, str):
return value.lower() in ("true", "yes", "1", "on")
return bool(value)
def require(self, key: str) -> Any:
"""
获取必需的配置值
Args:
key: 配置键名
Returns:
配置值
Raises:
KeyError: 配置不存在
"""
if not self._loaded:
self.load()
key_lower = key.lower()
if key_lower not in self._config:
raise KeyError(f"缺少必需的配置项: {key}")
return self._config[key_lower]
def all(self) -> dict[str, Any]:
"""获取所有配置的副本"""
if not self._loaded:
self.load()
return self._config.copy()
def __contains__(self, key: str) -> bool:
"""检查配置是否存在"""
if not self._loaded:
self.load()
return key.lower() in self._config
def __repr__(self) -> str:
sources = ", ".join(s.name for s in self._sources)
return f"ConfigLoader(sources=[{sources}], loaded={self._loaded})"
def create_config_loader(
yaml_path: str | Path = "config.yaml",
env_prefix: str = "SATONANO_",
yaml_required: bool = False,
) -> ConfigLoader:
"""
创建预配置的配置加载器
Args:
yaml_path: YAML 配置文件路径
env_prefix: 环境变量前缀
yaml_required: YAML 文件是否必须存在
Returns:
配置好的 ConfigLoader 实例
"""
loader = ConfigLoader()
loader.add_source(YamlConfigSource(yaml_path, required=yaml_required))
loader.add_source(EnvConfigSource(env_prefix))
return loader
@lru_cache
def get_config_loader() -> ConfigLoader:
"""获取全局配置加载器单例"""
return create_config_loader()
# 便捷访问的全局实例
config_loader = get_config_loader()

224
app/core/exceptions.py Normal file
View File

@@ -0,0 +1,224 @@
"""
自定义异常类
定义业务层面的异常,便于统一处理和返回合适的 HTTP 响应。
"""
from typing import Any
class AppException(Exception):
"""应用基础异常"""
def __init__(
self,
message: str,
code: str = "APP_ERROR",
details: dict[str, Any] | None = None,
):
self.message = message
self.code = code
self.details = details or {}
super().__init__(message)
class AuthenticationError(AppException):
"""认证错误"""
def __init__(
self,
message: str = "认证失败",
code: str = "AUTHENTICATION_ERROR",
details: dict[str, Any] | None = None,
):
super().__init__(message, code, details)
class InvalidCredentialsError(AuthenticationError):
"""无效凭证"""
def __init__(self, message: str = "用户名或密码错误"):
super().__init__(message, "INVALID_CREDENTIALS")
class TokenError(AuthenticationError):
"""令牌错误"""
def __init__(self, message: str = "令牌无效或已过期"):
super().__init__(message, "TOKEN_ERROR")
class TokenExpiredError(TokenError):
"""令牌过期"""
def __init__(self, message: str = "令牌已过期"):
super().__init__(message)
self.code = "TOKEN_EXPIRED"
class AuthorizationError(AppException):
"""授权错误"""
def __init__(
self,
message: str = "权限不足",
code: str = "AUTHORIZATION_ERROR",
details: dict[str, Any] | None = None,
):
super().__init__(message, code, details)
class ValidationError(AppException):
"""验证错误"""
def __init__(
self,
message: str = "数据验证失败",
code: str = "VALIDATION_ERROR",
details: dict[str, Any] | None = None,
):
super().__init__(message, code, details)
class ResourceNotFoundError(AppException):
"""资源未找到"""
def __init__(
self,
message: str = "资源不存在",
resource_type: str = "resource",
resource_id: Any = None,
):
super().__init__(
message,
"RESOURCE_NOT_FOUND",
{"resource_type": resource_type, "resource_id": resource_id},
)
class ResourceConflictError(AppException):
"""资源冲突(如重复创建)"""
def __init__(
self,
message: str = "资源已存在",
code: str = "RESOURCE_CONFLICT",
details: dict[str, Any] | None = None,
):
super().__init__(message, code, details)
class UserNotFoundError(ResourceNotFoundError):
"""用户不存在"""
def __init__(self, user_id: Any = None):
super().__init__("用户不存在", "user", user_id)
class UserAlreadyExistsError(ResourceConflictError):
"""用户已存在"""
def __init__(self, field: str = "username"):
super().__init__(
f"{field}已被注册",
"USER_ALREADY_EXISTS",
{"field": field},
)
class UserDisabledError(AuthenticationError):
"""用户被禁用"""
def __init__(self):
super().__init__("账户已被禁用", "USER_DISABLED")
class PasswordValidationError(ValidationError):
"""密码验证错误"""
def __init__(self, message: str = "密码不符合要求"):
super().__init__(message, "PASSWORD_VALIDATION_ERROR")
# ============================================================
# 余额相关异常
# ============================================================
class InsufficientBalanceError(AppException):
"""余额不足"""
def __init__(self, required: int, available: int):
super().__init__(
f"余额不足,需要 {required / 1000:.2f},当前可用 {available / 1000:.2f}",
"INSUFFICIENT_BALANCE",
{"required_units": required, "available_units": available},
)
class DuplicateTransactionError(AppException):
"""重复交易"""
def __init__(self, idempotency_key: str):
super().__init__(
"该交易已处理",
"DUPLICATE_TRANSACTION",
{"idempotency_key": idempotency_key},
)
class ConcurrencyError(AppException):
"""并发冲突"""
def __init__(self):
super().__init__(
"操作冲突,请重试",
"CONCURRENCY_ERROR",
)
# ============================================================
# 兑换码相关异常
# ============================================================
class RedeemCodeNotFoundError(AppException):
"""兑换码不存在"""
def __init__(self, code: str):
super().__init__(
"兑换码不存在",
"REDEEM_CODE_NOT_FOUND",
{"code": code},
)
class RedeemCodeInvalidError(AppException):
"""兑换码无效"""
def __init__(self, code: str, reason: str):
super().__init__(
f"兑换码无效: {reason}",
"REDEEM_CODE_INVALID",
{"code": code, "reason": reason},
)
class RedeemCodeExpiredError(AppException):
"""兑换码已过期"""
def __init__(self, code: str):
super().__init__(
"兑换码已过期",
"REDEEM_CODE_EXPIRED",
{"code": code},
)
class RedeemCodeUsedError(AppException):
"""兑换码已使用"""
def __init__(self, code: str):
super().__init__(
"兑换码已使用",
"REDEEM_CODE_USED",
{"code": code},
)

155
app/core/security.py Normal file
View File

@@ -0,0 +1,155 @@
"""
安全相关功能
包括密码哈希、JWT 令牌生成与验证。
使用 Argon2 作为密码哈希算法(目前最安全的选择)。
"""
from datetime import datetime, timedelta, timezone
from typing import Any
import jwt
from argon2 import PasswordHasher
from argon2.exceptions import InvalidHashError, VerifyMismatchError
from app.core.config import settings
# Argon2 密码哈希器,使用推荐的安全参数
_password_hasher = PasswordHasher(
time_cost=3, # 迭代次数
memory_cost=65536, # 内存使用 (64MB)
parallelism=4, # 并行度
)
def hash_password(password: str) -> str:
"""
对密码进行哈希处理
Args:
password: 明文密码
Returns:
Argon2 哈希字符串
"""
return _password_hasher.hash(password)
def verify_password(password: str, hashed_password: str) -> bool:
"""
验证密码是否匹配
Args:
password: 明文密码
hashed_password: 已哈希的密码
Returns:
密码是否正确
"""
try:
_password_hasher.verify(hashed_password, password)
return True
except (VerifyMismatchError, InvalidHashError):
return False
def password_needs_rehash(hashed_password: str) -> bool:
"""
检查密码是否需要重新哈希(参数升级时使用)
Args:
hashed_password: 已哈希的密码
Returns:
是否需要重新哈希
"""
return _password_hasher.check_needs_rehash(hashed_password)
def create_access_token(
subject: str | int,
expires_delta: timedelta | None = None,
extra_claims: dict[str, Any] | None = None,
) -> str:
"""
创建访问令牌
Args:
subject: 令牌主体通常是用户ID
expires_delta: 过期时间增量
extra_claims: 额外的声明数据
Returns:
JWT 访问令牌
"""
now = datetime.now(timezone.utc)
if expires_delta:
expire = now + expires_delta
else:
expire = now + timedelta(minutes=settings.access_token_expire_minutes)
payload = {
"sub": str(subject),
"iat": now,
"exp": expire,
"type": "access",
}
if extra_claims:
payload.update(extra_claims)
return jwt.encode(payload, settings.secret_key, algorithm=settings.algorithm)
def create_refresh_token(
subject: str | int,
expires_delta: timedelta | None = None,
) -> str:
"""
创建刷新令牌
Args:
subject: 令牌主体通常是用户ID
expires_delta: 过期时间增量
Returns:
JWT 刷新令牌
"""
now = datetime.now(timezone.utc)
if expires_delta:
expire = now + expires_delta
else:
expire = now + timedelta(days=settings.refresh_token_expire_days)
payload = {
"sub": str(subject),
"iat": now,
"exp": expire,
"type": "refresh",
}
return jwt.encode(payload, settings.secret_key, algorithm=settings.algorithm)
def decode_token(token: str) -> dict[str, Any]:
"""
解码并验证 JWT 令牌
Args:
token: JWT 令牌字符串
Returns:
解码后的负载数据
Raises:
jwt.InvalidTokenError: 令牌无效
jwt.ExpiredSignatureError: 令牌已过期
"""
return jwt.decode(
token,
settings.secret_key,
algorithms=[settings.algorithm],
)

100
app/database.py Normal file
View File

@@ -0,0 +1,100 @@
"""
数据库连接与会话管理
使用 SQLAlchemy 2.0 异步模式。
"""
import logging
from collections.abc import AsyncGenerator
from pathlib import Path
from sqlalchemy.ext.asyncio import (
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from sqlalchemy.orm import DeclarativeBase
from app.core.config import settings
logger = logging.getLogger(__name__)
class Base(DeclarativeBase):
"""SQLAlchemy 声明式基类"""
pass
def _ensure_sqlite_dir() -> None:
"""
确保 SQLite 数据库目录存在
如果配置使用 SQLite自动创建数据库文件所在的目录。
"""
if settings.database_type != "sqlite":
return
db_path = Path(settings.database_sqlite_path)
db_dir = db_path.parent
if not db_dir.exists():
logger.info(f"创建数据库目录: {db_dir.absolute()}")
db_dir.mkdir(parents=True, exist_ok=True)
# 确保 SQLite 目录存在
_ensure_sqlite_dir()
# 创建异步引擎
engine = create_async_engine(
settings.database_url,
echo=settings.database_echo,
future=True,
)
# 创建异步会话工厂
async_session_factory = async_sessionmaker(
engine,
class_=AsyncSession,
expire_on_commit=False,
autocommit=False,
autoflush=False,
)
async def get_db() -> AsyncGenerator[AsyncSession, None]:
"""
获取数据库会话(依赖注入用)
Yields:
异步数据库会话
"""
async with async_session_factory() as session:
try:
yield session
finally:
await session.close()
async def init_db() -> None:
"""初始化数据库(创建所有表)"""
# 导入所有模型以确保它们被注册
from app.models import ( # noqa: F401
User,
UserBalance,
BalanceTransaction,
RedeemCode,
RedeemCodeBatch,
RedeemCodeUsageLog,
)
logger.info(f"初始化数据库: {settings.database_url}")
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
logger.info("数据库初始化完成")
async def close_db() -> None:
"""关闭数据库连接"""
await engine.dispose()

268
app/main.py Normal file
View File

@@ -0,0 +1,268 @@
"""
FastAPI 应用入口
配置应用实例、中间件、异常处理器等。
"""
from contextlib import asynccontextmanager
from pathlib import Path
from typing import AsyncGenerator
from fastapi import FastAPI, Request, status
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from app.api.v1.router import api_router
from app.core.config import settings
from app.core.exceptions import (
AppException,
AuthenticationError,
AuthorizationError,
ResourceConflictError,
ResourceNotFoundError,
ValidationError,
)
from app.database import close_db, init_db
def get_frontend_dir() -> Path:
"""
获取前端静态文件目录
支持相对路径和绝对路径配置
"""
path = Path(settings.frontend_static_path)
if path.is_absolute():
return path
# 相对路径基于项目根目录
return Path(__file__).parent.parent / path
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
"""
应用生命周期管理
启动时初始化数据库,关闭时释放资源。
"""
# 启动时
await init_db()
yield
# 关闭时
await close_db()
def create_application() -> FastAPI:
"""
创建 FastAPI 应用实例
Returns:
配置完成的 FastAPI 应用
"""
app = FastAPI(
title=settings.app_name,
version=settings.app_version,
description="现代化用户认证系统 API",
docs_url="/docs" if not settings.is_production else None,
redoc_url="/redoc" if not settings.is_production else None,
openapi_url="/openapi.json" if not settings.is_production else None,
lifespan=lifespan,
)
# 配置 CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["*"] if settings.debug else [],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 注册异常处理器
register_exception_handlers(app)
# 注册路由
app.include_router(api_router, prefix=settings.api_v1_prefix)
# 健康检查端点
@app.get("/health", tags=["健康检查"])
async def health_check():
"""健康检查接口"""
return {"status": "healthy", "version": settings.app_version}
# 挂载前端静态文件(当静态文件目录存在时)
frontend_dir = get_frontend_dir()
if frontend_dir.exists():
setup_frontend_routes(app, frontend_dir)
return app
def setup_frontend_routes(app: FastAPI, frontend_dir: Path) -> None:
"""
配置前端静态文件路由
Args:
app: FastAPI 应用实例
frontend_dir: 前端静态文件目录
"""
# 挂载 _next 静态资源目录Next.js 构建产物)
next_static_dir = frontend_dir / "_next"
if next_static_dir.exists():
app.mount(
"/_next",
StaticFiles(directory=str(next_static_dir)),
name="next_static",
)
# 挂载 assets 目录(通用静态资源)
assets_dir = frontend_dir / "assets"
if assets_dir.exists():
app.mount(
"/assets",
StaticFiles(directory=str(assets_dir)),
name="assets",
)
# 挂载 static 目录(通用静态资源)
static_dir = frontend_dir / "static"
if static_dir.exists():
app.mount(
"/static",
StaticFiles(directory=str(static_dir)),
name="static",
)
# SPA 路由处理 - 必须在最后注册,以避免覆盖 API 路由
@app.get("/{full_path:path}", include_in_schema=False)
async def serve_spa(request: Request, full_path: str):
"""
SPA 路由处理
- 对于静态文件请求,直接返回文件
- 对于 SPA 路由请求,返回 index.html
"""
# 跳过 API 路由和健康检查
if full_path.startswith("api/") or full_path == "health":
return JSONResponse(
status_code=404,
content={"success": False, "message": "Not Found"},
)
# 尝试查找静态文件
file_path = frontend_dir / full_path
# 如果是目录,尝试查找 index.html
if file_path.is_dir():
file_path = file_path / "index.html"
# 如果文件存在,返回文件
if file_path.is_file():
return FileResponse(
file_path,
headers=_get_cache_headers(full_path),
)
# 尝试添加 .html 后缀Next.js 静态导出格式)
html_path = frontend_dir / f"{full_path}.html"
if html_path.is_file():
return FileResponse(html_path)
# 默认返回 index.htmlSPA 路由回退)
index_path = frontend_dir / "index.html"
if index_path.is_file():
return FileResponse(index_path)
# 前端文件不存在
return JSONResponse(
status_code=404,
content={"success": False, "message": "Not Found"},
)
def _get_cache_headers(path: str) -> dict[str, str]:
"""
根据文件类型返回适当的缓存头
Args:
path: 请求路径
Returns:
缓存相关的 HTTP 头
"""
# 静态资源(带 hash 的文件)使用长期缓存
if "/_next/" in path or path.startswith("_next/"):
return {"Cache-Control": "public, max-age=31536000, immutable"}
# 其他静态文件使用短期缓存
static_extensions = {".js", ".css", ".png", ".jpg", ".jpeg", ".gif", ".svg", ".ico", ".woff", ".woff2"}
if any(path.endswith(ext) for ext in static_extensions):
return {"Cache-Control": "public, max-age=86400"}
# HTML 文件不缓存
return {"Cache-Control": "no-cache"}
def register_exception_handlers(app: FastAPI) -> None:
"""注册全局异常处理器"""
@app.exception_handler(AppException)
async def app_exception_handler(
request: Request,
exc: AppException,
) -> JSONResponse:
"""应用异常处理"""
status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
if isinstance(exc, AuthenticationError):
status_code = status.HTTP_401_UNAUTHORIZED
elif isinstance(exc, AuthorizationError):
status_code = status.HTTP_403_FORBIDDEN
elif isinstance(exc, ResourceNotFoundError):
status_code = status.HTTP_404_NOT_FOUND
elif isinstance(exc, ResourceConflictError):
status_code = status.HTTP_409_CONFLICT
elif isinstance(exc, ValidationError):
status_code = status.HTTP_422_UNPROCESSABLE_ENTITY
return JSONResponse(
status_code=status_code,
content={
"success": False,
"message": exc.message,
"code": exc.code,
"details": exc.details,
},
)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(
request: Request,
exc: RequestValidationError,
) -> JSONResponse:
"""请求验证异常处理"""
errors = []
for error in exc.errors():
field = ".".join(str(loc) for loc in error["loc"])
errors.append({
"field": field,
"message": error["msg"],
"type": error["type"],
})
return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
content={
"success": False,
"message": "请求数据验证失败",
"code": "VALIDATION_ERROR",
"details": {"errors": errors},
},
)
# 创建应用实例
app = create_application()

28
app/models/__init__.py Normal file
View File

@@ -0,0 +1,28 @@
"""数据库模型"""
from app.models.user import User
from app.models.balance import (
UserBalance,
BalanceTransaction,
TransactionType,
TransactionStatus,
)
from app.models.redeem_code import (
RedeemCode,
RedeemCodeBatch,
RedeemCodeUsageLog,
RedeemCodeStatus,
)
__all__ = [
"User",
"UserBalance",
"BalanceTransaction",
"TransactionType",
"TransactionStatus",
"RedeemCode",
"RedeemCodeBatch",
"RedeemCodeUsageLog",
"RedeemCodeStatus",
]

320
app/models/balance.py Normal file
View File

@@ -0,0 +1,320 @@
"""
余额与交易模型
定义用户余额、余额交易记录相关的数据表结构。
设计说明:
- 余额内部以无符号整数存储(单位额度),避免浮点精度问题
- 1.00 显示余额 = 1000 单位额度(精度 0.001
- 所有金额操作都在单位额度层面进行,只在展示时转换
"""
from datetime import datetime, timezone
from enum import Enum
from typing import TYPE_CHECKING
from uuid import uuid4
from sqlalchemy import (
BigInteger,
DateTime,
Enum as SQLEnum,
ForeignKey,
Index,
String,
Text,
func,
)
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database import Base
if TYPE_CHECKING:
from app.models.user import User
def generate_uuid() -> str:
"""生成 UUID 字符串"""
return str(uuid4())
def utc_now() -> datetime:
"""获取当前 UTC 时间"""
return datetime.now(timezone.utc)
class TransactionType(str, Enum):
"""交易类型枚举"""
RECHARGE = "recharge" # 充值(兑换码等)
DEDUCTION = "deduction" # 扣款API 调用等)
REFUND = "refund" # 退款
ADJUSTMENT = "adjustment" # 管理员调整
TRANSFER_IN = "transfer_in" # 转入
TRANSFER_OUT = "transfer_out" # 转出
class TransactionStatus(str, Enum):
"""交易状态枚举"""
PENDING = "pending" # 待处理
COMPLETED = "completed" # 已完成
FAILED = "failed" # 失败
CANCELLED = "cancelled" # 已取消
class UserBalance(Base):
"""
用户余额模型
独立的余额表,便于扩展(如多币种、账户类型)和锁管理
"""
__tablename__ = "user_balances"
# 主键
id: Mapped[str] = mapped_column(
String(36),
primary_key=True,
default=generate_uuid,
comment="余额记录唯一标识",
)
# 关联用户(一对一)
user_id: Mapped[str] = mapped_column(
String(36),
ForeignKey("users.id", ondelete="CASCADE"),
unique=True,
nullable=False,
index=True,
comment="关联用户 ID",
)
# 余额信息(内部以整数单位存储)
balance: Mapped[int] = mapped_column(
BigInteger,
default=0,
nullable=False,
comment="当前余额单位额度1000 = 1.00 显示余额)",
)
# 冻结金额(用于处理中的交易)
frozen_balance: Mapped[int] = mapped_column(
BigInteger,
default=0,
nullable=False,
comment="冻结余额(单位额度)",
)
# 累计统计
total_recharged: Mapped[int] = mapped_column(
BigInteger,
default=0,
nullable=False,
comment="累计充值(单位额度)",
)
total_consumed: Mapped[int] = mapped_column(
BigInteger,
default=0,
nullable=False,
comment="累计消费(单位额度)",
)
# 乐观锁版本号
version: Mapped[int] = mapped_column(
BigInteger,
default=0,
nullable=False,
comment="版本号(乐观锁)",
)
# 时间戳
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
default=utc_now,
server_default=func.now(),
nullable=False,
comment="创建时间",
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
default=utc_now,
onupdate=utc_now,
server_default=func.now(),
nullable=False,
comment="更新时间",
)
# 关系
user: Mapped["User"] = relationship(
"User",
back_populates="balance_account",
lazy="selectin",
)
transactions: Mapped[list["BalanceTransaction"]] = relationship(
"BalanceTransaction",
back_populates="balance_account",
lazy="selectin",
order_by="desc(BalanceTransaction.created_at)",
)
@property
def available_balance(self) -> int:
"""可用余额(总余额 - 冻结余额)"""
return self.balance - self.frozen_balance
@property
def display_balance(self) -> str:
"""显示余额2 位小数)"""
return f"{self.balance / 1000:.2f}"
@property
def display_available_balance(self) -> str:
"""显示可用余额2 位小数)"""
return f"{self.available_balance / 1000:.2f}"
def __repr__(self) -> str:
return f"<UserBalance(user_id={self.user_id!r}, balance={self.display_balance})>"
class BalanceTransaction(Base):
"""
余额交易记录模型
记录所有余额变动,用于审计和对账
"""
__tablename__ = "balance_transactions"
__table_args__ = (
Index("ix_balance_transactions_user_created", "user_id", "created_at"),
Index("ix_balance_transactions_type_status", "transaction_type", "status"),
)
# 主键
id: Mapped[str] = mapped_column(
String(36),
primary_key=True,
default=generate_uuid,
comment="交易记录唯一标识",
)
# 关联
user_id: Mapped[str] = mapped_column(
String(36),
ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
index=True,
comment="关联用户 ID",
)
balance_account_id: Mapped[str] = mapped_column(
String(36),
ForeignKey("user_balances.id", ondelete="CASCADE"),
nullable=False,
index=True,
comment="关联余额账户 ID",
)
# 交易信息
transaction_type: Mapped[TransactionType] = mapped_column(
SQLEnum(TransactionType),
nullable=False,
comment="交易类型",
)
status: Mapped[TransactionStatus] = mapped_column(
SQLEnum(TransactionStatus),
default=TransactionStatus.COMPLETED,
nullable=False,
comment="交易状态",
)
# 金额(整数单位)
amount: Mapped[int] = mapped_column(
BigInteger,
nullable=False,
comment="交易金额(单位额度,正数表示收入,负数表示支出)",
)
balance_before: Mapped[int] = mapped_column(
BigInteger,
nullable=False,
comment="交易前余额(单位额度)",
)
balance_after: Mapped[int] = mapped_column(
BigInteger,
nullable=False,
comment="交易后余额(单位额度)",
)
# 业务关联
reference_type: Mapped[str | None] = mapped_column(
String(64),
nullable=True,
index=True,
comment="关联业务类型(如 redeem_code、api_call",
)
reference_id: Mapped[str | None] = mapped_column(
String(64),
nullable=True,
index=True,
comment="关联业务 ID",
)
# 描述
description: Mapped[str | None] = mapped_column(
String(255),
nullable=True,
comment="交易描述",
)
remark: Mapped[str | None] = mapped_column(
Text,
nullable=True,
comment="备注(内部使用)",
)
# 操作人(管理员调整时记录)
operator_id: Mapped[str | None] = mapped_column(
String(36),
ForeignKey("users.id", ondelete="SET NULL"),
nullable=True,
comment="操作人 ID管理员调整时",
)
# 幂等键(防止重复提交)
idempotency_key: Mapped[str | None] = mapped_column(
String(64),
unique=True,
nullable=True,
comment="幂等键",
)
# 时间戳
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
default=utc_now,
server_default=func.now(),
nullable=False,
comment="创建时间",
)
# 关系
user: Mapped["User"] = relationship(
"User",
foreign_keys=[user_id],
lazy="selectin",
)
balance_account: Mapped["UserBalance"] = relationship(
"UserBalance",
back_populates="transactions",
lazy="selectin",
)
@property
def display_amount(self) -> str:
"""显示金额带符号2 位小数)"""
return f"{self.amount / 1000:+.2f}"
def __repr__(self) -> str:
return (
f"<BalanceTransaction(id={self.id!r}, "
f"type={self.transaction_type.value}, "
f"amount={self.display_amount})>"
)

409
app/models/redeem_code.py Normal file
View File

@@ -0,0 +1,409 @@
"""
兑换码模型
定义余额兑换码相关的数据表结构。
设计说明:
- 兑换码支持批量生成、导入导出
- 记录完整的使用日志
- 支持设置有效期和使用限制
"""
from datetime import datetime, timezone
from enum import Enum
from typing import TYPE_CHECKING
from uuid import uuid4
import secrets
import string
from sqlalchemy import (
BigInteger,
Boolean,
DateTime,
Enum as SQLEnum,
ForeignKey,
Index,
Integer,
String,
Text,
func,
)
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database import Base
if TYPE_CHECKING:
from app.models.user import User
def generate_uuid() -> str:
"""生成 UUID 字符串"""
return str(uuid4())
def utc_now() -> datetime:
"""获取当前 UTC 时间"""
return datetime.now(timezone.utc)
def generate_redeem_code(length: int = 16) -> str:
"""
生成兑换码
格式:大写字母和数字,分段显示(如 XXXX-XXXX-XXXX-XXXX
排除容易混淆的字符0/O, 1/I/L
"""
# 可用字符集(排除易混淆字符)
alphabet = "ABCDEFGHJKMNPQRSTUVWXYZ23456789"
code = "".join(secrets.choice(alphabet) for _ in range(length))
# 每 4 个字符用连字符分隔
return "-".join(code[i:i + 4] for i in range(0, len(code), 4))
class RedeemCodeStatus(str, Enum):
"""兑换码状态枚举"""
ACTIVE = "active" # 可用
USED = "used" # 已使用
DISABLED = "disabled" # 已禁用
EXPIRED = "expired" # 已过期
class RedeemCodeBatch(Base):
"""
兑换码批次模型
用于管理批量生成的兑换码
"""
__tablename__ = "redeem_code_batches"
# 主键
id: Mapped[str] = mapped_column(
String(36),
primary_key=True,
default=generate_uuid,
comment="批次唯一标识",
)
# 批次信息
name: Mapped[str] = mapped_column(
String(128),
nullable=False,
comment="批次名称",
)
description: Mapped[str | None] = mapped_column(
Text,
nullable=True,
comment="批次描述",
)
# 面值(单位额度)
face_value: Mapped[int] = mapped_column(
BigInteger,
nullable=False,
comment="面值单位额度1000 = 1.00",
)
# 数量统计
total_count: Mapped[int] = mapped_column(
Integer,
default=0,
nullable=False,
comment="生成总数",
)
used_count: Mapped[int] = mapped_column(
Integer,
default=0,
nullable=False,
comment="已使用数量",
)
# 创建者
created_by: Mapped[str] = mapped_column(
String(36),
ForeignKey("users.id", ondelete="SET NULL"),
nullable=True,
comment="创建者 ID",
)
# 时间戳
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
default=utc_now,
server_default=func.now(),
nullable=False,
comment="创建时间",
)
# 关系
codes: Mapped[list["RedeemCode"]] = relationship(
"RedeemCode",
back_populates="batch",
lazy="selectin",
)
creator: Mapped["User"] = relationship(
"User",
lazy="selectin",
)
@property
def display_face_value(self) -> str:
"""显示面值2 位小数)"""
return f"{self.face_value / 1000:.2f}"
def __repr__(self) -> str:
return f"<RedeemCodeBatch(id={self.id!r}, name={self.name!r})>"
class RedeemCode(Base):
"""
兑换码模型
单个兑换码记录
"""
__tablename__ = "redeem_codes"
__table_args__ = (
Index("ix_redeem_codes_status_expires", "status", "expires_at"),
)
# 主键
id: Mapped[str] = mapped_column(
String(36),
primary_key=True,
default=generate_uuid,
comment="兑换码记录唯一标识",
)
# 兑换码(唯一索引)
code: Mapped[str] = mapped_column(
String(32),
unique=True,
nullable=False,
index=True,
default=generate_redeem_code,
comment="兑换码",
)
# 批次关联(可选)
batch_id: Mapped[str | None] = mapped_column(
String(36),
ForeignKey("redeem_code_batches.id", ondelete="SET NULL"),
nullable=True,
index=True,
comment="关联批次 ID",
)
# 面值(单位额度)
face_value: Mapped[int] = mapped_column(
BigInteger,
nullable=False,
comment="面值单位额度1000 = 1.00",
)
# 状态
status: Mapped[RedeemCodeStatus] = mapped_column(
SQLEnum(RedeemCodeStatus),
default=RedeemCodeStatus.ACTIVE,
nullable=False,
index=True,
comment="兑换码状态",
)
# 使用限制
max_uses: Mapped[int] = mapped_column(
Integer,
default=1,
nullable=False,
comment="最大使用次数",
)
used_count: Mapped[int] = mapped_column(
Integer,
default=0,
nullable=False,
comment="已使用次数",
)
# 有效期
expires_at: Mapped[datetime | None] = mapped_column(
DateTime(timezone=True),
nullable=True,
comment="过期时间",
)
# 使用信息
used_by: Mapped[str | None] = mapped_column(
String(36),
ForeignKey("users.id", ondelete="SET NULL"),
nullable=True,
comment="使用者 ID最后使用",
)
used_at: Mapped[datetime | None] = mapped_column(
DateTime(timezone=True),
nullable=True,
comment="使用时间(最后使用)",
)
# 备注
remark: Mapped[str | None] = mapped_column(
Text,
nullable=True,
comment="备注",
)
# 创建者
created_by: Mapped[str | None] = mapped_column(
String(36),
ForeignKey("users.id", ondelete="SET NULL"),
nullable=True,
comment="创建者 ID",
)
# 时间戳
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
default=utc_now,
server_default=func.now(),
nullable=False,
comment="创建时间",
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
default=utc_now,
onupdate=utc_now,
server_default=func.now(),
nullable=False,
comment="更新时间",
)
# 关系
batch: Mapped["RedeemCodeBatch | None"] = relationship(
"RedeemCodeBatch",
back_populates="codes",
lazy="selectin",
)
usage_logs: Mapped[list["RedeemCodeUsageLog"]] = relationship(
"RedeemCodeUsageLog",
back_populates="redeem_code",
lazy="selectin",
order_by="desc(RedeemCodeUsageLog.created_at)",
)
@property
def display_face_value(self) -> str:
"""显示面值2 位小数)"""
return f"{self.face_value / 1000:.2f}"
@property
def is_valid(self) -> bool:
"""检查兑换码是否有效"""
if self.status != RedeemCodeStatus.ACTIVE:
return False
if self.used_count >= self.max_uses:
return False
if self.expires_at and self.expires_at < utc_now():
return False
return True
def __repr__(self) -> str:
return f"<RedeemCode(code={self.code!r}, status={self.status.value})>"
class RedeemCodeUsageLog(Base):
"""
兑换码使用日志模型
记录每次兑换的详细信息
"""
__tablename__ = "redeem_code_usage_logs"
__table_args__ = (
Index("ix_redeem_usage_user_created", "user_id", "created_at"),
)
# 主键
id: Mapped[str] = mapped_column(
String(36),
primary_key=True,
default=generate_uuid,
comment="日志唯一标识",
)
# 关联
redeem_code_id: Mapped[str] = mapped_column(
String(36),
ForeignKey("redeem_codes.id", ondelete="CASCADE"),
nullable=False,
index=True,
comment="关联兑换码 ID",
)
user_id: Mapped[str] = mapped_column(
String(36),
ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
index=True,
comment="使用者 ID",
)
transaction_id: Mapped[str | None] = mapped_column(
String(36),
ForeignKey("balance_transactions.id", ondelete="SET NULL"),
nullable=True,
comment="关联交易记录 ID",
)
# 兑换信息快照
code_snapshot: Mapped[str] = mapped_column(
String(32),
nullable=False,
comment="兑换码快照",
)
face_value: Mapped[int] = mapped_column(
BigInteger,
nullable=False,
comment="面值(单位额度)",
)
# 客户端信息
ip_address: Mapped[str | None] = mapped_column(
String(64),
nullable=True,
comment="客户端 IP",
)
user_agent: Mapped[str | None] = mapped_column(
String(512),
nullable=True,
comment="User Agent",
)
# 时间戳
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
default=utc_now,
server_default=func.now(),
nullable=False,
comment="使用时间",
)
# 关系
redeem_code: Mapped["RedeemCode"] = relationship(
"RedeemCode",
back_populates="usage_logs",
lazy="selectin",
)
user: Mapped["User"] = relationship(
"User",
lazy="selectin",
)
@property
def display_face_value(self) -> str:
"""显示面值2 位小数)"""
return f"{self.face_value / 1000:.2f}"
def __repr__(self) -> str:
return (
f"<RedeemCodeUsageLog(code={self.code_snapshot!r}, "
f"user_id={self.user_id!r})>"
)

141
app/models/user.py Normal file
View File

@@ -0,0 +1,141 @@
"""
用户模型
定义用户数据表结构。
"""
from datetime import datetime, timezone
from typing import TYPE_CHECKING
from uuid import uuid4
from sqlalchemy import Boolean, DateTime, String, Text, func
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.database import Base
if TYPE_CHECKING:
from app.models.balance import UserBalance
def generate_uuid() -> str:
"""生成 UUID 字符串"""
return str(uuid4())
def utc_now() -> datetime:
"""获取当前 UTC 时间"""
return datetime.now(timezone.utc)
class User(Base):
"""用户模型"""
__tablename__ = "users"
# 主键:使用 UUID 字符串
id: Mapped[str] = mapped_column(
String(36),
primary_key=True,
default=generate_uuid,
comment="用户唯一标识",
)
# 账户信息
username: Mapped[str] = mapped_column(
String(32),
unique=True,
index=True,
nullable=False,
comment="用户名",
)
email: Mapped[str | None] = mapped_column(
String(255),
unique=True,
index=True,
nullable=True,
comment="邮箱地址",
)
hashed_password: Mapped[str | None] = mapped_column(
String(255),
nullable=True, # OAuth2 用户可能没有密码
comment="密码哈希",
)
# OAuth2 关联信息
oauth_provider: Mapped[str | None] = mapped_column(
String(32),
nullable=True,
index=True,
comment="OAuth2 提供商(如 linuxdo",
)
oauth_user_id: Mapped[str | None] = mapped_column(
String(128),
nullable=True,
index=True,
comment="OAuth2 用户 ID",
)
# 用户状态
is_active: Mapped[bool] = mapped_column(
Boolean,
default=True,
nullable=False,
comment="是否激活",
)
is_superuser: Mapped[bool] = mapped_column(
Boolean,
default=False,
nullable=False,
comment="是否为超级管理员",
)
# 个人信息
nickname: Mapped[str | None] = mapped_column(
String(64),
nullable=True,
comment="昵称",
)
avatar_url: Mapped[str | None] = mapped_column(
String(512),
nullable=True,
comment="头像 URL",
)
bio: Mapped[str | None] = mapped_column(
Text,
nullable=True,
comment="个人简介",
)
# 时间戳
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
default=utc_now,
server_default=func.now(),
nullable=False,
comment="创建时间",
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
default=utc_now,
onupdate=utc_now,
server_default=func.now(),
nullable=False,
comment="更新时间",
)
last_login_at: Mapped[datetime | None] = mapped_column(
DateTime(timezone=True),
nullable=True,
comment="最后登录时间",
)
# 关系
balance_account: Mapped["UserBalance | None"] = relationship(
"UserBalance",
back_populates="user",
uselist=False,
lazy="selectin",
)
def __repr__(self) -> str:
return f"<User(id={self.id!r}, username={self.username!r})>"

View File

@@ -0,0 +1,19 @@
"""数据仓库层"""
from app.repositories.user import UserRepository
from app.repositories.balance import BalanceRepository, TransactionRepository
from app.repositories.redeem_code import (
RedeemCodeRepository,
RedeemCodeBatchRepository,
RedeemCodeUsageLogRepository,
)
__all__ = [
"UserRepository",
"BalanceRepository",
"TransactionRepository",
"RedeemCodeRepository",
"RedeemCodeBatchRepository",
"RedeemCodeUsageLogRepository",
]

378
app/repositories/balance.py Normal file
View File

@@ -0,0 +1,378 @@
"""
余额仓库
处理余额相关的数据库操作。
设计说明:
- 使用乐观锁version防止并发更新冲突
- 提供行级锁支持(悲观锁)用于关键操作
"""
from datetime import datetime
from typing import Any
from sqlalchemy import select, update, and_
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.balance import (
UserBalance,
BalanceTransaction,
TransactionType,
TransactionStatus,
)
from app.repositories.base import BaseRepository
class BalanceRepository(BaseRepository[UserBalance]):
"""余额仓库"""
model = UserBalance
async def get_by_user_id(self, user_id: str) -> UserBalance | None:
"""
通过用户 ID 获取余额账户
Args:
user_id: 用户 ID
Returns:
余额账户或 None
"""
stmt = select(UserBalance).where(UserBalance.user_id == user_id)
result = await self.session.execute(stmt)
return result.scalar_one_or_none()
async def get_by_user_id_for_update(self, user_id: str) -> UserBalance | None:
"""
通过用户 ID 获取余额账户(加行级锁)
用于需要原子性更新的场景,如扣款操作。
Args:
user_id: 用户 ID
Returns:
余额账户或 None
"""
stmt = (
select(UserBalance)
.where(UserBalance.user_id == user_id)
.with_for_update() # 行级锁
)
result = await self.session.execute(stmt)
return result.scalar_one_or_none()
async def get_or_create(self, user_id: str) -> UserBalance:
"""
获取或创建余额账户
如果用户没有余额账户,自动创建一个。
Args:
user_id: 用户 ID
Returns:
余额账户
"""
balance = await self.get_by_user_id(user_id)
if balance is None:
balance = await self.create(user_id=user_id)
return balance
async def get_or_create_for_update(self, user_id: str) -> UserBalance:
"""
获取或创建余额账户(加行级锁)
Args:
user_id: 用户 ID
Returns:
余额账户
"""
balance = await self.get_by_user_id_for_update(user_id)
if balance is None:
balance = await self.create(user_id=user_id)
# 重新获取并加锁
await self.session.flush()
balance = await self.get_by_user_id_for_update(user_id)
return balance # type: ignore
async def update_balance_optimistic(
self,
balance: UserBalance,
delta: int,
*,
is_recharge: bool = False,
is_consumption: bool = False,
) -> bool:
"""
使用乐观锁更新余额
通过版本号检查确保并发安全。
Args:
balance: 余额账户
delta: 变化量(正数增加,负数减少)
is_recharge: 是否为充值
is_consumption: 是否为消费
Returns:
是否更新成功
"""
current_version = balance.version
new_balance = balance.balance + delta
# 构建更新语句
update_values: dict[str, Any] = {
"balance": new_balance,
"version": current_version + 1,
}
if is_recharge and delta > 0:
update_values["total_recharged"] = balance.total_recharged + delta
if is_consumption and delta < 0:
update_values["total_consumed"] = balance.total_consumed + abs(delta)
stmt = (
update(UserBalance)
.where(
and_(
UserBalance.id == balance.id,
UserBalance.version == current_version, # 乐观锁检查
)
)
.values(**update_values)
)
result = await self.session.execute(stmt)
if result.rowcount == 1:
# 更新成功,刷新对象
balance.balance = new_balance
balance.version = current_version + 1
if is_recharge and delta > 0:
balance.total_recharged += delta
if is_consumption and delta < 0:
balance.total_consumed += abs(delta)
return True
return False
async def freeze_balance(
self,
balance: UserBalance,
amount: int,
) -> bool:
"""
冻结余额
Args:
balance: 余额账户
amount: 冻结金额(正数)
Returns:
是否成功
"""
if amount <= 0:
return False
if balance.available_balance < amount:
return False
current_version = balance.version
stmt = (
update(UserBalance)
.where(
and_(
UserBalance.id == balance.id,
UserBalance.version == current_version,
UserBalance.balance - UserBalance.frozen_balance >= amount,
)
)
.values(
frozen_balance=UserBalance.frozen_balance + amount,
version=current_version + 1,
)
)
result = await self.session.execute(stmt)
if result.rowcount == 1:
balance.frozen_balance += amount
balance.version = current_version + 1
return True
return False
async def unfreeze_balance(
self,
balance: UserBalance,
amount: int,
) -> bool:
"""
解冻余额
Args:
balance: 余额账户
amount: 解冻金额(正数)
Returns:
是否成功
"""
if amount <= 0:
return False
if balance.frozen_balance < amount:
return False
current_version = balance.version
stmt = (
update(UserBalance)
.where(
and_(
UserBalance.id == balance.id,
UserBalance.version == current_version,
UserBalance.frozen_balance >= amount,
)
)
.values(
frozen_balance=UserBalance.frozen_balance - amount,
version=current_version + 1,
)
)
result = await self.session.execute(stmt)
if result.rowcount == 1:
balance.frozen_balance -= amount
balance.version = current_version + 1
return True
return False
class TransactionRepository(BaseRepository[BalanceTransaction]):
"""交易记录仓库"""
model = BalanceTransaction
async def get_by_user_id(
self,
user_id: str,
*,
offset: int = 0,
limit: int = 20,
transaction_type: TransactionType | None = None,
status: TransactionStatus | None = None,
) -> list[BalanceTransaction]:
"""
获取用户的交易记录
Args:
user_id: 用户 ID
offset: 偏移量
limit: 限制数量
transaction_type: 交易类型过滤
status: 状态过滤
Returns:
交易记录列表
"""
stmt = select(BalanceTransaction).where(
BalanceTransaction.user_id == user_id
)
if transaction_type:
stmt = stmt.where(BalanceTransaction.transaction_type == transaction_type)
if status:
stmt = stmt.where(BalanceTransaction.status == status)
stmt = (
stmt
.order_by(BalanceTransaction.created_at.desc())
.offset(offset)
.limit(limit)
)
result = await self.session.execute(stmt)
return list(result.scalars().all())
async def count_by_user_id(
self,
user_id: str,
*,
transaction_type: TransactionType | None = None,
status: TransactionStatus | None = None,
) -> int:
"""
统计用户的交易记录数量
Args:
user_id: 用户 ID
transaction_type: 交易类型过滤
status: 状态过滤
Returns:
记录数量
"""
from sqlalchemy import func
stmt = select(func.count()).select_from(BalanceTransaction).where(
BalanceTransaction.user_id == user_id
)
if transaction_type:
stmt = stmt.where(BalanceTransaction.transaction_type == transaction_type)
if status:
stmt = stmt.where(BalanceTransaction.status == status)
result = await self.session.execute(stmt)
return result.scalar() or 0
async def get_by_idempotency_key(
self,
idempotency_key: str,
) -> BalanceTransaction | None:
"""
通过幂等键获取交易记录
用于防止重复提交。
Args:
idempotency_key: 幂等键
Returns:
交易记录或 None
"""
stmt = select(BalanceTransaction).where(
BalanceTransaction.idempotency_key == idempotency_key
)
result = await self.session.execute(stmt)
return result.scalar_one_or_none()
async def get_by_reference(
self,
reference_type: str,
reference_id: str,
) -> list[BalanceTransaction]:
"""
通过业务关联获取交易记录
Args:
reference_type: 关联业务类型
reference_id: 关联业务 ID
Returns:
交易记录列表
"""
stmt = (
select(BalanceTransaction)
.where(
and_(
BalanceTransaction.reference_type == reference_type,
BalanceTransaction.reference_id == reference_id,
)
)
.order_by(BalanceTransaction.created_at.desc())
)
result = await self.session.execute(stmt)
return list(result.scalars().all())

138
app/repositories/base.py Normal file
View File

@@ -0,0 +1,138 @@
"""
基础仓库类
提供通用的 CRUD 操作封装。
"""
from typing import Any, Generic, TypeVar
from uuid import uuid4
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import Base
ModelT = TypeVar("ModelT", bound=Base)
class BaseRepository(Generic[ModelT]):
"""
基础仓库类
提供通用的数据库操作方法。
"""
model: type[ModelT]
def __init__(self, session: AsyncSession):
"""
初始化仓库
Args:
session: 异步数据库会话
"""
self.session = session
async def get_by_id(self, id: str) -> ModelT | None:
"""
通过 ID 获取实体
Args:
id: 实体 ID
Returns:
实体对象或 None
"""
return await self.session.get(self.model, id)
async def get_all(
self,
*,
offset: int = 0,
limit: int = 100,
) -> list[ModelT]:
"""
获取所有实体
Args:
offset: 偏移量
limit: 限制数量
Returns:
实体列表
"""
stmt = select(self.model).offset(offset).limit(limit)
result = await self.session.execute(stmt)
return list(result.scalars().all())
async def count(self) -> int:
"""
获取实体总数
Returns:
实体数量
"""
stmt = select(func.count()).select_from(self.model)
result = await self.session.execute(stmt)
return result.scalar() or 0
async def create(self, **kwargs: Any) -> ModelT:
"""
创建新实体
Args:
**kwargs: 实体属性
Returns:
新创建的实体
"""
if "id" not in kwargs:
kwargs["id"] = str(uuid4())
entity = self.model(**kwargs)
self.session.add(entity)
await self.session.flush()
await self.session.refresh(entity)
return entity
async def update(
self,
entity: ModelT,
**kwargs: Any,
) -> ModelT:
"""
更新实体
Args:
entity: 要更新的实体
**kwargs: 要更新的属性
Returns:
更新后的实体
"""
for key, value in kwargs.items():
if hasattr(entity, key):
setattr(entity, key, value)
await self.session.flush()
await self.session.refresh(entity)
return entity
async def delete(self, entity: ModelT) -> None:
"""
删除实体
Args:
entity: 要删除的实体
"""
await self.session.delete(entity)
await self.session.flush()
async def commit(self) -> None:
"""提交事务"""
await self.session.commit()
async def rollback(self) -> None:
"""回滚事务"""
await self.session.rollback()

View File

@@ -0,0 +1,462 @@
"""
兑换码仓库
处理兑换码相关的数据库操作。
"""
from datetime import datetime, timezone
from typing import Any
from sqlalchemy import select, update, and_, or_, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.redeem_code import (
RedeemCode,
RedeemCodeBatch,
RedeemCodeUsageLog,
RedeemCodeStatus,
)
from app.repositories.base import BaseRepository
class RedeemCodeRepository(BaseRepository[RedeemCode]):
"""兑换码仓库"""
model = RedeemCode
async def get_by_code(self, code: str) -> RedeemCode | None:
"""
通过兑换码获取记录
Args:
code: 兑换码
Returns:
兑换码记录或 None
"""
# 标准化兑换码格式
normalized_code = code.strip().upper().replace(" ", "")
stmt = select(RedeemCode).where(RedeemCode.code == normalized_code)
result = await self.session.execute(stmt)
return result.scalar_one_or_none()
async def get_by_code_for_update(self, code: str) -> RedeemCode | None:
"""
通过兑换码获取记录(加行级锁)
用于兑换操作,防止并发兑换。
Args:
code: 兑换码
Returns:
兑换码记录或 None
"""
normalized_code = code.strip().upper().replace(" ", "")
stmt = (
select(RedeemCode)
.where(RedeemCode.code == normalized_code)
.with_for_update()
)
result = await self.session.execute(stmt)
return result.scalar_one_or_none()
async def get_valid_code(self, code: str) -> RedeemCode | None:
"""
获取有效的兑换码
检查状态、使用次数和有效期。
Args:
code: 兑换码
Returns:
有效的兑换码或 None
"""
normalized_code = code.strip().upper().replace(" ", "")
now = datetime.now(timezone.utc)
stmt = (
select(RedeemCode)
.where(
and_(
RedeemCode.code == normalized_code,
RedeemCode.status == RedeemCodeStatus.ACTIVE,
RedeemCode.used_count < RedeemCode.max_uses,
or_(
RedeemCode.expires_at.is_(None),
RedeemCode.expires_at > now,
),
)
)
)
result = await self.session.execute(stmt)
return result.scalar_one_or_none()
async def mark_as_used(
self,
code: RedeemCode,
user_id: str,
) -> bool:
"""
标记兑换码已使用
更新使用计数和状态。
Args:
code: 兑换码记录
user_id: 使用者 ID
Returns:
是否成功
"""
now = datetime.now(timezone.utc)
new_used_count = code.used_count + 1
new_status = (
RedeemCodeStatus.USED
if new_used_count >= code.max_uses
else RedeemCodeStatus.ACTIVE
)
stmt = (
update(RedeemCode)
.where(
and_(
RedeemCode.id == code.id,
RedeemCode.used_count == code.used_count, # 乐观锁
)
)
.values(
used_count=new_used_count,
status=new_status,
used_by=user_id,
used_at=now,
)
)
result = await self.session.execute(stmt)
if result.rowcount == 1:
code.used_count = new_used_count
code.status = new_status
code.used_by = user_id
code.used_at = now
return True
return False
async def get_all_with_filters(
self,
*,
offset: int = 0,
limit: int = 20,
status: RedeemCodeStatus | None = None,
batch_id: str | None = None,
code_like: str | None = None,
created_after: datetime | None = None,
created_before: datetime | None = None,
) -> list[RedeemCode]:
"""
获取兑换码列表(支持过滤)
Args:
offset: 偏移量
limit: 限制数量
status: 状态过滤
batch_id: 批次 ID 过滤
code_like: 兑换码模糊匹配
created_after: 创建时间起始
created_before: 创建时间结束
Returns:
兑换码列表
"""
stmt = select(RedeemCode)
conditions = []
if status:
conditions.append(RedeemCode.status == status)
if batch_id:
conditions.append(RedeemCode.batch_id == batch_id)
if code_like:
conditions.append(RedeemCode.code.contains(code_like.upper()))
if created_after:
conditions.append(RedeemCode.created_at >= created_after)
if created_before:
conditions.append(RedeemCode.created_at <= created_before)
if conditions:
stmt = stmt.where(and_(*conditions))
stmt = (
stmt
.order_by(RedeemCode.created_at.desc())
.offset(offset)
.limit(limit)
)
result = await self.session.execute(stmt)
return list(result.scalars().all())
async def count_with_filters(
self,
*,
status: RedeemCodeStatus | None = None,
batch_id: str | None = None,
code_like: str | None = None,
created_after: datetime | None = None,
created_before: datetime | None = None,
) -> int:
"""
统计兑换码数量(支持过滤)
"""
stmt = select(func.count()).select_from(RedeemCode)
conditions = []
if status:
conditions.append(RedeemCode.status == status)
if batch_id:
conditions.append(RedeemCode.batch_id == batch_id)
if code_like:
conditions.append(RedeemCode.code.contains(code_like.upper()))
if created_after:
conditions.append(RedeemCode.created_at >= created_after)
if created_before:
conditions.append(RedeemCode.created_at <= created_before)
if conditions:
stmt = stmt.where(and_(*conditions))
result = await self.session.execute(stmt)
return result.scalar() or 0
async def bulk_create(
self,
codes_data: list[dict[str, Any]],
) -> list[RedeemCode]:
"""
批量创建兑换码
Args:
codes_data: 兑换码数据列表
Returns:
创建的兑换码列表
"""
codes = [RedeemCode(**data) for data in codes_data]
self.session.add_all(codes)
await self.session.flush()
return codes
async def disable_code(self, code: RedeemCode) -> RedeemCode:
"""
禁用兑换码
Args:
code: 兑换码记录
Returns:
更新后的兑换码
"""
return await self.update(code, status=RedeemCodeStatus.DISABLED)
async def enable_code(self, code: RedeemCode) -> RedeemCode:
"""
启用兑换码
Args:
code: 兑换码记录
Returns:
更新后的兑换码
"""
# 只有禁用状态的可以重新启用
if code.status != RedeemCodeStatus.DISABLED:
return code
# 如果使用次数已满,改为已使用状态
if code.used_count >= code.max_uses:
return await self.update(code, status=RedeemCodeStatus.USED)
return await self.update(code, status=RedeemCodeStatus.ACTIVE)
class RedeemCodeBatchRepository(BaseRepository[RedeemCodeBatch]):
"""兑换码批次仓库"""
model = RedeemCodeBatch
async def get_all_batches(
self,
*,
offset: int = 0,
limit: int = 20,
) -> list[RedeemCodeBatch]:
"""
获取所有批次
Args:
offset: 偏移量
limit: 限制数量
Returns:
批次列表
"""
stmt = (
select(RedeemCodeBatch)
.order_by(RedeemCodeBatch.created_at.desc())
.offset(offset)
.limit(limit)
)
result = await self.session.execute(stmt)
return list(result.scalars().all())
async def increment_used_count(self, batch_id: str) -> None:
"""
增加批次已使用计数
Args:
batch_id: 批次 ID
"""
stmt = (
update(RedeemCodeBatch)
.where(RedeemCodeBatch.id == batch_id)
.values(used_count=RedeemCodeBatch.used_count + 1)
)
await self.session.execute(stmt)
class RedeemCodeUsageLogRepository(BaseRepository[RedeemCodeUsageLog]):
"""兑换码使用日志仓库"""
model = RedeemCodeUsageLog
async def get_by_code_id(
self,
redeem_code_id: str,
*,
offset: int = 0,
limit: int = 20,
) -> list[RedeemCodeUsageLog]:
"""
获取兑换码的使用日志
Args:
redeem_code_id: 兑换码 ID
offset: 偏移量
limit: 限制数量
Returns:
使用日志列表
"""
stmt = (
select(RedeemCodeUsageLog)
.where(RedeemCodeUsageLog.redeem_code_id == redeem_code_id)
.order_by(RedeemCodeUsageLog.created_at.desc())
.offset(offset)
.limit(limit)
)
result = await self.session.execute(stmt)
return list(result.scalars().all())
async def get_by_user_id(
self,
user_id: str,
*,
offset: int = 0,
limit: int = 20,
) -> list[RedeemCodeUsageLog]:
"""
获取用户的兑换日志
Args:
user_id: 用户 ID
offset: 偏移量
limit: 限制数量
Returns:
使用日志列表
"""
stmt = (
select(RedeemCodeUsageLog)
.where(RedeemCodeUsageLog.user_id == user_id)
.order_by(RedeemCodeUsageLog.created_at.desc())
.offset(offset)
.limit(limit)
)
result = await self.session.execute(stmt)
return list(result.scalars().all())
async def get_all_with_filters(
self,
*,
offset: int = 0,
limit: int = 20,
redeem_code_id: str | None = None,
user_id: str | None = None,
code_like: str | None = None,
created_after: datetime | None = None,
created_before: datetime | None = None,
) -> list[RedeemCodeUsageLog]:
"""
获取使用日志(支持过滤)
"""
stmt = select(RedeemCodeUsageLog)
conditions = []
if redeem_code_id:
conditions.append(RedeemCodeUsageLog.redeem_code_id == redeem_code_id)
if user_id:
conditions.append(RedeemCodeUsageLog.user_id == user_id)
if code_like:
conditions.append(RedeemCodeUsageLog.code_snapshot.contains(code_like.upper()))
if created_after:
conditions.append(RedeemCodeUsageLog.created_at >= created_after)
if created_before:
conditions.append(RedeemCodeUsageLog.created_at <= created_before)
if conditions:
stmt = stmt.where(and_(*conditions))
stmt = (
stmt
.order_by(RedeemCodeUsageLog.created_at.desc())
.offset(offset)
.limit(limit)
)
result = await self.session.execute(stmt)
return list(result.scalars().all())
async def count_with_filters(
self,
*,
redeem_code_id: str | None = None,
user_id: str | None = None,
code_like: str | None = None,
created_after: datetime | None = None,
created_before: datetime | None = None,
) -> int:
"""
统计使用日志数量
"""
stmt = select(func.count()).select_from(RedeemCodeUsageLog)
conditions = []
if redeem_code_id:
conditions.append(RedeemCodeUsageLog.redeem_code_id == redeem_code_id)
if user_id:
conditions.append(RedeemCodeUsageLog.user_id == user_id)
if code_like:
conditions.append(RedeemCodeUsageLog.code_snapshot.contains(code_like.upper()))
if created_after:
conditions.append(RedeemCodeUsageLog.created_at >= created_after)
if created_before:
conditions.append(RedeemCodeUsageLog.created_at <= created_before)
if conditions:
stmt = stmt.where(and_(*conditions))
result = await self.session.execute(stmt)
return result.scalar() or 0

141
app/repositories/user.py Normal file
View File

@@ -0,0 +1,141 @@
"""
用户仓库
处理用户相关的数据库操作。
"""
from sqlalchemy import or_, select
from app.models.user import User
from app.repositories.base import BaseRepository
class UserRepository(BaseRepository[User]):
"""用户数据仓库"""
model = User
async def get_by_username(self, username: str) -> User | None:
"""
通过用户名获取用户
Args:
username: 用户名
Returns:
用户对象或 None
"""
stmt = select(User).where(User.username == username.lower())
result = await self.session.execute(stmt)
return result.scalar_one_or_none()
async def get_by_email(self, email: str) -> User | None:
"""
通过邮箱获取用户
Args:
email: 邮箱地址
Returns:
用户对象或 None
"""
stmt = select(User).where(User.email == email.lower())
result = await self.session.execute(stmt)
return result.scalar_one_or_none()
async def get_by_username_or_email(self, identifier: str) -> User | None:
"""
通过用户名或邮箱获取用户
Args:
identifier: 用户名或邮箱
Returns:
用户对象或 None
"""
identifier_lower = identifier.lower()
stmt = select(User).where(
or_(
User.username == identifier_lower,
User.email == identifier_lower,
)
)
result = await self.session.execute(stmt)
return result.scalar_one_or_none()
async def exists_by_username(self, username: str) -> bool:
"""
检查用户名是否存在
Args:
username: 用户名
Returns:
是否存在
"""
user = await self.get_by_username(username)
return user is not None
async def exists_by_email(self, email: str) -> bool:
"""
检查邮箱是否存在
Args:
email: 邮箱地址
Returns:
是否存在
"""
if not email:
return False
user = await self.get_by_email(email)
return user is not None
async def get_by_oauth(
self,
provider: str,
oauth_user_id: str,
) -> User | None:
"""
通过 OAuth2 提供商和用户 ID 获取用户
Args:
provider: OAuth2 提供商标识
oauth_user_id: OAuth2 用户 ID
Returns:
用户对象或 None
"""
stmt = select(User).where(
User.oauth_provider == provider,
User.oauth_user_id == oauth_user_id,
)
result = await self.session.execute(stmt)
return result.scalar_one_or_none()
async def get_active_users(
self,
*,
offset: int = 0,
limit: int = 100,
) -> list[User]:
"""
获取活跃用户列表
Args:
offset: 偏移量
limit: 限制数量
Returns:
活跃用户列表
"""
stmt = (
select(User)
.where(User.is_active == True) # noqa: E712
.offset(offset)
.limit(limit)
.order_by(User.created_at.desc())
)
result = await self.session.execute(stmt)
return list(result.scalars().all())

75
app/schemas/__init__.py Normal file
View File

@@ -0,0 +1,75 @@
"""Pydantic 数据模式"""
from app.schemas.auth import (
LoginRequest,
PasswordChangeRequest,
RefreshTokenRequest,
TokenResponse,
)
from app.schemas.oauth2 import (
OAuth2AuthorizeResponse,
OAuth2CallbackRequest,
OAuth2LoginResponse,
OAuth2TokenData,
OAuth2UserInfo,
)
from app.schemas.user import (
UserCreate,
UserResponse,
UserUpdate,
)
from app.schemas.balance import (
BalanceResponse,
TransactionResponse,
DeductionRequest,
DeductionResponse,
AdminAdjustmentRequest,
AdminBalanceResponse,
)
from app.schemas.redeem_code import (
RedeemRequest,
RedeemResponse,
RedeemCodeResponse,
BatchCreateRequest,
BatchResponse,
BulkImportRequest,
BulkImportResponse,
ExportResponse,
UsageLogResponse,
)
__all__ = [
# Auth
"LoginRequest",
"TokenResponse",
"RefreshTokenRequest",
"PasswordChangeRequest",
# OAuth2
"OAuth2AuthorizeResponse",
"OAuth2CallbackRequest",
"OAuth2LoginResponse",
"OAuth2TokenData",
"OAuth2UserInfo",
# User
"UserCreate",
"UserResponse",
"UserUpdate",
# Balance
"BalanceResponse",
"TransactionResponse",
"DeductionRequest",
"DeductionResponse",
"AdminAdjustmentRequest",
"AdminBalanceResponse",
# Redeem Code
"RedeemRequest",
"RedeemResponse",
"RedeemCodeResponse",
"BatchCreateRequest",
"BatchResponse",
"BulkImportRequest",
"BulkImportResponse",
"ExportResponse",
"UsageLogResponse",
]

112
app/schemas/auth.py Normal file
View File

@@ -0,0 +1,112 @@
"""
认证相关 Schema
定义登录、令牌等数据结构。
"""
from typing import Annotated
from pydantic import Field
from app.schemas.base import BaseSchema
class LoginRequest(BaseSchema):
"""登录请求"""
username: Annotated[
str,
Field(
min_length=1,
description="用户名或邮箱",
examples=["john_doe"],
),
]
password: Annotated[
str,
Field(
min_length=1,
description="密码",
examples=["SecurePass123"],
),
]
class TokenResponse(BaseSchema):
"""令牌响应"""
access_token: str = Field(description="访问令牌")
refresh_token: str = Field(description="刷新令牌")
token_type: str = Field(default="Bearer", description="令牌类型")
expires_in: int = Field(description="访问令牌过期时间(秒)")
class RefreshTokenRequest(BaseSchema):
"""刷新令牌请求"""
refresh_token: Annotated[
str,
Field(
min_length=1,
description="刷新令牌",
),
]
class PasswordChangeRequest(BaseSchema):
"""修改密码请求"""
current_password: Annotated[
str,
Field(
min_length=1,
description="当前密码",
),
]
new_password: Annotated[
str,
Field(
min_length=8,
max_length=128,
description="新密码",
),
]
def model_post_init(self, __context) -> None:
"""验证新旧密码不同"""
if self.current_password == self.new_password:
raise ValueError("新密码不能与当前密码相同")
class PasswordResetRequest(BaseSchema):
"""密码重置请求(忘记密码)"""
email: Annotated[
str,
Field(
min_length=1,
description="注册邮箱",
examples=["user@example.com"],
),
]
class PasswordResetConfirm(BaseSchema):
"""密码重置确认"""
token: Annotated[
str,
Field(
min_length=1,
description="重置令牌",
),
]
new_password: Annotated[
str,
Field(
min_length=8,
max_length=128,
description="新密码",
),
]

285
app/schemas/balance.py Normal file
View File

@@ -0,0 +1,285 @@
"""
余额相关 Schema
定义余额数据的验证和序列化规则。
设计说明:
- 内部存储使用整数单位units外部显示使用小数display
- 1.00 显示余额 = 1000 单位额度
"""
from datetime import datetime
from typing import Annotated
from pydantic import Field, field_validator, computed_field
from app.schemas.base import BaseSchema, PaginatedResponse
from app.models.balance import TransactionType, TransactionStatus
# ============================================================
# 余额单位转换工具
# ============================================================
UNITS_PER_DISPLAY = 1000 # 1.00 显示余额 = 1000 单位额度
def display_to_units(display_amount: float) -> int:
"""将显示金额转换为单位额度"""
return int(round(display_amount * UNITS_PER_DISPLAY))
def units_to_display(units: int) -> float:
"""将单位额度转换为显示金额"""
return units / UNITS_PER_DISPLAY
def format_display(units: int) -> str:
"""格式化显示金额2 位小数)"""
return f"{units / UNITS_PER_DISPLAY:.2f}"
# ============================================================
# 余额账户 Schema
# ============================================================
class BalanceResponse(BaseSchema):
"""余额信息响应"""
user_id: str
balance_units: Annotated[
int,
Field(description="当前余额(单位额度)"),
]
frozen_units: Annotated[
int,
Field(description="冻结余额(单位额度)"),
]
total_recharged_units: Annotated[
int,
Field(description="累计充值(单位额度)"),
]
total_consumed_units: Annotated[
int,
Field(description="累计消费(单位额度)"),
]
@computed_field
@property
def balance(self) -> str:
"""显示余额2 位小数)"""
return format_display(self.balance_units)
@computed_field
@property
def available_balance(self) -> str:
"""显示可用余额2 位小数)"""
return format_display(self.balance_units - self.frozen_units)
@computed_field
@property
def frozen_balance(self) -> str:
"""显示冻结余额2 位小数)"""
return format_display(self.frozen_units)
@computed_field
@property
def total_recharged(self) -> str:
"""显示累计充值2 位小数)"""
return format_display(self.total_recharged_units)
@computed_field
@property
def total_consumed(self) -> str:
"""显示累计消费2 位小数)"""
return format_display(self.total_consumed_units)
class BalanceSummaryResponse(BaseSchema):
"""余额简要信息响应(用于嵌入用户信息)"""
balance: str = Field(description="当前余额")
available_balance: str = Field(description="可用余额")
# ============================================================
# 交易记录 Schema
# ============================================================
class TransactionResponse(BaseSchema):
"""交易记录响应"""
id: str
transaction_type: TransactionType
status: TransactionStatus
amount_units: Annotated[
int,
Field(description="交易金额(单位额度,正数收入,负数支出)"),
]
balance_before_units: Annotated[
int,
Field(description="交易前余额(单位额度)"),
]
balance_after_units: Annotated[
int,
Field(description="交易后余额(单位额度)"),
]
reference_type: str | None
reference_id: str | None
description: str | None
created_at: datetime
@computed_field
@property
def amount(self) -> str:
"""显示交易金额带符号2 位小数)"""
return f"{self.amount_units / UNITS_PER_DISPLAY:+.2f}"
@computed_field
@property
def balance_before(self) -> str:
"""显示交易前余额2 位小数)"""
return format_display(self.balance_before_units)
@computed_field
@property
def balance_after(self) -> str:
"""显示交易后余额2 位小数)"""
return format_display(self.balance_after_units)
class TransactionListResponse(PaginatedResponse[TransactionResponse]):
"""交易记录列表响应"""
pass
# ============================================================
# 扣款请求 Schema
# ============================================================
class DeductionRequest(BaseSchema):
"""扣款请求"""
amount: Annotated[
float,
Field(
gt=0,
description="扣款金额(显示金额,如 1.00",
examples=[1.00, 0.50],
),
]
reference_type: Annotated[
str | None,
Field(
default=None,
max_length=64,
description="关联业务类型",
examples=["api_call", "service_fee"],
),
]
reference_id: Annotated[
str | None,
Field(
default=None,
max_length=64,
description="关联业务 ID",
),
]
description: Annotated[
str | None,
Field(
default=None,
max_length=255,
description="交易描述",
),
]
idempotency_key: Annotated[
str | None,
Field(
default=None,
max_length=64,
description="幂等键(防止重复扣款)",
),
]
@field_validator("amount")
@classmethod
def validate_amount(cls, v: float) -> float:
"""验证金额精度"""
# 最小精度 0.001(即 1 单位额度)
if round(v, 3) != v:
raise ValueError("金额精度不能超过 3 位小数")
return v
@property
def amount_units(self) -> int:
"""转换为单位额度"""
return display_to_units(self.amount)
class DeductionResponse(BaseSchema):
"""扣款响应"""
transaction_id: str
amount: str = Field(description="扣款金额")
balance_before: str = Field(description="扣款前余额")
balance_after: str = Field(description="扣款后余额")
# ============================================================
# 管理员操作 Schema
# ============================================================
class AdminAdjustmentRequest(BaseSchema):
"""管理员余额调整请求"""
user_id: Annotated[
str,
Field(description="目标用户 ID"),
]
amount: Annotated[
float,
Field(
description="调整金额(正数增加,负数减少)",
examples=[10.00, -5.00],
),
]
reason: Annotated[
str,
Field(
min_length=1,
max_length=255,
description="调整原因",
),
]
@field_validator("amount")
@classmethod
def validate_amount(cls, v: float) -> float:
"""验证金额精度"""
if round(v, 3) != v:
raise ValueError("金额精度不能超过 3 位小数")
if v == 0:
raise ValueError("调整金额不能为 0")
return v
@property
def amount_units(self) -> int:
"""转换为单位额度"""
return display_to_units(self.amount)
class AdminBalanceResponse(BaseSchema):
"""管理员查看的余额信息(包含更多细节)"""
user_id: str
username: str
balance: str
available_balance: str
frozen_balance: str
total_recharged: str
total_consumed: str
version: int
created_at: datetime
updated_at: datetime

102
app/schemas/base.py Normal file
View File

@@ -0,0 +1,102 @@
"""
基础 Schema 定义
定义通用的响应模式和基础配置。
"""
from datetime import datetime
from typing import Any, Generic, TypeVar
from pydantic import BaseModel, ConfigDict
DataT = TypeVar("DataT")
class BaseSchema(BaseModel):
"""基础 Schema 配置"""
model_config = ConfigDict(
from_attributes=True, # 支持从 ORM 对象创建
populate_by_name=True, # 支持字段别名
str_strip_whitespace=True, # 自动去除字符串首尾空白
)
class TimestampMixin(BaseModel):
"""时间戳混入"""
created_at: datetime
updated_at: datetime
class APIResponse(BaseModel, Generic[DataT]):
"""统一 API 响应格式"""
success: bool = True
message: str = "操作成功"
data: DataT | None = None
@classmethod
def ok(cls, data: DataT | None = None, message: str = "操作成功") -> "APIResponse[DataT]":
"""成功响应"""
return cls(success=True, message=message, data=data)
@classmethod
def error(cls, message: str = "操作失败", data: DataT | None = None) -> "APIResponse[DataT]":
"""错误响应"""
return cls(success=False, message=message, data=data)
class ErrorResponse(BaseModel):
"""错误响应"""
success: bool = False
message: str
code: str
details: dict[str, Any] = {}
class PaginationParams(BaseModel):
"""分页参数"""
page: int = 1
page_size: int = 20
@property
def offset(self) -> int:
"""计算偏移量"""
return (self.page - 1) * self.page_size
@property
def limit(self) -> int:
"""获取限制数量"""
return self.page_size
class PaginatedResponse(BaseModel, Generic[DataT]):
"""分页响应"""
items: list[DataT]
total: int
page: int
page_size: int
total_pages: int
@classmethod
def create(
cls,
items: list[DataT],
total: int,
page: int,
page_size: int,
) -> "PaginatedResponse[DataT]":
"""创建分页响应"""
total_pages = (total + page_size - 1) // page_size if page_size > 0 else 0
return cls(
items=items,
total=total,
page=page,
page_size=page_size,
total_pages=total_pages,
)

85
app/schemas/oauth2.py Normal file
View File

@@ -0,0 +1,85 @@
"""
OAuth2 相关 Schema
定义 OAuth2 认证流程的数据结构。
"""
from typing import Annotated
from pydantic import Field
from app.schemas.base import BaseSchema
class OAuth2AuthorizeResponse(BaseSchema):
"""OAuth2 授权 URL 响应"""
authorize_url: str = Field(description="OAuth2 授权页面 URL")
state: str = Field(description="防 CSRF 状态码")
class OAuth2CallbackRequest(BaseSchema):
"""OAuth2 回调请求"""
code: Annotated[
str,
Field(
min_length=1,
description="授权码",
),
]
state: Annotated[
str,
Field(
min_length=1,
description="状态码(用于验证请求合法性)",
),
]
class OAuth2TokenData(BaseSchema):
"""OAuth2 令牌数据(从 OAuth 提供商获取)"""
access_token: str
token_type: str = "Bearer"
expires_in: int | None = None
refresh_token: str | None = None
scope: str | None = None
class OAuth2UserInfo(BaseSchema):
"""
OAuth2 用户信息(从 Linux.do 获取)
示例响应:
{
"id": 1,
"username": "neo",
"name": "Neo",
"active": true,
"trust_level": 4,
"email": "u1@linux.do",
"avatar_url": "https://linux.do/xxxx",
"silenced": false
}
"""
id: int | str = Field(description="用户 ID")
username: str = Field(description="用户名")
name: str | None = Field(default=None, description="显示名称")
email: str | None = Field(default=None, description="邮箱")
avatar_url: str | None = Field(default=None, description="头像 URL")
active: bool = Field(default=True, description="是否激活")
trust_level: int | None = Field(default=None, description="信任等级")
silenced: bool = Field(default=False, description="是否被禁言")
class OAuth2LoginResponse(BaseSchema):
"""OAuth2 登录响应"""
access_token: str = Field(description="JWT 访问令牌")
refresh_token: str = Field(description="JWT 刷新令牌")
token_type: str = Field(default="Bearer", description="令牌类型")
expires_in: int = Field(description="访问令牌过期时间(秒)")
is_new_user: bool = Field(description="是否为新注册用户")

346
app/schemas/redeem_code.py Normal file
View File

@@ -0,0 +1,346 @@
"""
兑换码相关 Schema
定义兑换码数据的验证和序列化规则。
"""
from datetime import datetime
from typing import Annotated
from pydantic import Field, field_validator
from app.schemas.base import BaseSchema, PaginatedResponse
from app.schemas.balance import UNITS_PER_DISPLAY, display_to_units, format_display
from app.models.redeem_code import RedeemCodeStatus
# ============================================================
# 用户兑换 Schema
# ============================================================
class RedeemRequest(BaseSchema):
"""兑换请求"""
code: Annotated[
str,
Field(
min_length=1,
max_length=32,
description="兑换码",
examples=["ABCD-EFGH-JKLM-NPQR"],
),
]
@field_validator("code")
@classmethod
def normalize_code(cls, v: str) -> str:
"""标准化兑换码格式"""
# 移除空格,转大写
return v.strip().upper().replace(" ", "")
class RedeemResponse(BaseSchema):
"""兑换响应"""
success: bool = True
message: str = "兑换成功"
face_value: str = Field(description="兑换金额")
balance_before: str = Field(description="兑换前余额")
balance_after: str = Field(description="兑换后余额")
# ============================================================
# 兑换码信息 Schema
# ============================================================
class RedeemCodeResponse(BaseSchema):
"""兑换码信息响应"""
id: str
code: str
face_value_units: int = Field(description="面值(单位额度)")
status: RedeemCodeStatus
max_uses: int
used_count: int
expires_at: datetime | None
used_at: datetime | None
created_at: datetime
@property
def face_value(self) -> str:
"""显示面值2 位小数)"""
return format_display(self.face_value_units)
@property
def is_valid(self) -> bool:
"""是否有效"""
if self.status != RedeemCodeStatus.ACTIVE:
return False
if self.used_count >= self.max_uses:
return False
if self.expires_at and self.expires_at < datetime.now(self.expires_at.tzinfo):
return False
return True
class RedeemCodeDetailResponse(RedeemCodeResponse):
"""兑换码详细信息响应(管理员用)"""
batch_id: str | None
batch_name: str | None = None
remark: str | None
created_by: str | None
used_by: str | None
class RedeemCodeListResponse(PaginatedResponse[RedeemCodeDetailResponse]):
"""兑换码列表响应"""
pass
# ============================================================
# 批次 Schema
# ============================================================
class BatchCreateRequest(BaseSchema):
"""创建兑换码批次请求"""
name: Annotated[
str,
Field(
min_length=1,
max_length=128,
description="批次名称",
examples=["2024年1月活动"],
),
]
description: Annotated[
str | None,
Field(
default=None,
max_length=500,
description="批次描述",
),
]
face_value: Annotated[
float,
Field(
gt=0,
description="面值(显示金额)",
examples=[10.00, 50.00],
),
]
count: Annotated[
int,
Field(
gt=0,
le=10000,
description="生成数量",
examples=[100],
),
]
max_uses: Annotated[
int,
Field(
default=1,
gt=0,
le=100,
description="每个兑换码最大使用次数",
),
]
expires_at: Annotated[
datetime | None,
Field(
default=None,
description="过期时间",
),
]
@field_validator("face_value")
@classmethod
def validate_face_value(cls, v: float) -> float:
"""验证面值精度"""
if round(v, 3) != v:
raise ValueError("面值精度不能超过 3 位小数")
return v
@property
def face_value_units(self) -> int:
"""转换为单位额度"""
return display_to_units(self.face_value)
class BatchResponse(BaseSchema):
"""批次信息响应"""
id: str
name: str
description: str | None
face_value_units: int
total_count: int
used_count: int
created_by: str | None
created_at: datetime
@property
def face_value(self) -> str:
"""显示面值2 位小数)"""
return format_display(self.face_value_units)
@property
def unused_count(self) -> int:
"""未使用数量"""
return self.total_count - self.used_count
class BatchDetailResponse(BatchResponse):
"""批次详细信息响应"""
codes: list[RedeemCodeResponse] = []
class BatchListResponse(PaginatedResponse[BatchResponse]):
"""批次列表响应"""
pass
# ============================================================
# 导入导出 Schema
# ============================================================
class ImportCodeRequest(BaseSchema):
"""导入兑换码请求"""
code: Annotated[
str,
Field(
min_length=1,
max_length=32,
description="兑换码",
),
]
face_value: Annotated[
float,
Field(
gt=0,
description="面值",
),
]
max_uses: Annotated[
int,
Field(
default=1,
gt=0,
),
]
expires_at: datetime | None = None
remark: str | None = None
@field_validator("code")
@classmethod
def normalize_code(cls, v: str) -> str:
"""标准化兑换码"""
return v.strip().upper()
@property
def face_value_units(self) -> int:
"""转换为单位额度"""
return display_to_units(self.face_value)
class BulkImportRequest(BaseSchema):
"""批量导入兑换码请求"""
codes: Annotated[
list[ImportCodeRequest],
Field(
min_length=1,
max_length=1000,
description="兑换码列表",
),
]
batch_name: Annotated[
str | None,
Field(
default=None,
max_length=128,
description="批次名称(可选)",
),
]
class BulkImportResponse(BaseSchema):
"""批量导入响应"""
success_count: int
failed_count: int
failed_codes: list[str] = []
batch_id: str | None = None
class ExportCodeItem(BaseSchema):
"""导出兑换码条目"""
code: str
face_value: str
status: str
max_uses: int
used_count: int
expires_at: str | None
created_at: str
used_at: str | None
used_by: str | None
class ExportResponse(BaseSchema):
"""导出响应"""
total: int
codes: list[ExportCodeItem]
# ============================================================
# 使用日志 Schema
# ============================================================
class UsageLogResponse(BaseSchema):
"""兑换码使用日志响应"""
id: str
redeem_code_id: str
code_snapshot: str
user_id: str
username: str | None = None
face_value: str
ip_address: str | None
created_at: datetime
class UsageLogListResponse(PaginatedResponse[UsageLogResponse]):
"""使用日志列表响应"""
pass
# ============================================================
# 查询参数 Schema
# ============================================================
class RedeemCodeQueryParams(BaseSchema):
"""兑换码查询参数"""
status: RedeemCodeStatus | None = None
batch_id: str | None = None
code: str | None = None
created_after: datetime | None = None
created_before: datetime | None = None
class UsageLogQueryParams(BaseSchema):
"""使用日志查询参数"""
redeem_code_id: str | None = None
user_id: str | None = None
code: str | None = None
created_after: datetime | None = None
created_before: datetime | None = None

156
app/schemas/user.py Normal file
View File

@@ -0,0 +1,156 @@
"""
用户相关 Schema
定义用户数据的验证和序列化规则。
"""
import re
from datetime import datetime
from typing import Annotated
from pydantic import EmailStr, Field, field_validator
from app.core.config import settings
from app.schemas.base import BaseSchema
# 用户名正则:字母开头,只允许字母、数字、下划线
USERNAME_PATTERN = re.compile(r"^[a-zA-Z][a-zA-Z0-9_]*$")
class UserBase(BaseSchema):
"""用户基础字段"""
username: Annotated[
str,
Field(
min_length=settings.username_min_length,
max_length=settings.username_max_length,
description="用户名(字母开头,只允许字母、数字、下划线)",
examples=["john_doe"],
),
]
email: Annotated[
EmailStr | None,
Field(
default=None,
description="邮箱地址",
examples=["user@example.com"],
),
]
nickname: Annotated[
str | None,
Field(
default=None,
max_length=64,
description="昵称",
examples=["John"],
),
]
@field_validator("username")
@classmethod
def validate_username(cls, v: str) -> str:
"""验证用户名格式"""
if not USERNAME_PATTERN.match(v):
raise ValueError("用户名必须以字母开头,只能包含字母、数字和下划线")
return v.lower() # 统一转小写
class UserCreate(UserBase):
"""用户注册请求"""
password: Annotated[
str,
Field(
min_length=settings.password_min_length,
max_length=settings.password_max_length,
description="密码",
examples=["SecurePass123"],
),
]
@field_validator("password")
@classmethod
def validate_password(cls, v: str) -> str:
"""验证密码强度"""
errors: list[str] = []
if settings.password_require_uppercase and not re.search(r"[A-Z]", v):
errors.append("至少包含一个大写字母")
if settings.password_require_lowercase and not re.search(r"[a-z]", v):
errors.append("至少包含一个小写字母")
if settings.password_require_digit and not re.search(r"\d", v):
errors.append("至少包含一个数字")
if settings.password_require_special and not re.search(r"[!@#$%^&*(),.?\":{}|<>]", v):
errors.append("至少包含一个特殊字符")
if errors:
raise ValueError("密码强度不足:" + "".join(errors))
return v
class UserUpdate(BaseSchema):
"""用户信息更新请求"""
nickname: Annotated[
str | None,
Field(
default=None,
max_length=64,
description="昵称",
),
]
email: Annotated[
EmailStr | None,
Field(
default=None,
description="邮箱地址",
),
]
avatar_url: Annotated[
str | None,
Field(
default=None,
max_length=512,
description="头像 URL",
),
]
bio: Annotated[
str | None,
Field(
default=None,
max_length=500,
description="个人简介",
),
]
class UserResponse(BaseSchema):
"""用户信息响应"""
id: str
username: str
email: str | None
nickname: str | None
avatar_url: str | None
bio: str | None
is_active: bool
created_at: datetime
last_login_at: datetime | None
class UserProfileResponse(BaseSchema):
"""用户公开资料响应(不包含敏感信息)"""
id: str
username: str
nickname: str | None
avatar_url: str | None
bio: str | None
created_at: datetime

16
app/services/__init__.py Normal file
View File

@@ -0,0 +1,16 @@
"""业务服务层"""
from app.services.auth import AuthService
from app.services.oauth2 import OAuth2Service
from app.services.user import UserService
from app.services.balance import BalanceService
from app.services.redeem_code import RedeemCodeService
__all__ = [
"AuthService",
"OAuth2Service",
"UserService",
"BalanceService",
"RedeemCodeService",
]

296
app/services/auth.py Normal file
View File

@@ -0,0 +1,296 @@
"""
认证服务
处理用户认证相关的业务逻辑。
"""
from datetime import timedelta
import jwt
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.exceptions import (
InvalidCredentialsError,
PasswordValidationError,
TokenError,
TokenExpiredError,
UserDisabledError,
UserNotFoundError,
)
from app.core.security import (
create_access_token,
create_refresh_token,
decode_token,
hash_password,
password_needs_rehash,
verify_password,
)
from app.models.user import User
from app.repositories.user import UserRepository
from app.schemas.auth import PasswordChangeRequest, TokenResponse
from app.services.user import UserService
class AuthService:
"""认证服务"""
def __init__(self, session: AsyncSession):
"""
初始化认证服务
Args:
session: 数据库会话
"""
self.session = session
self.user_repo = UserRepository(session)
self.user_service = UserService(session)
async def authenticate(
self,
username: str,
password: str,
) -> User:
"""
验证用户凭证
Args:
username: 用户名或邮箱
password: 密码
Returns:
验证成功的用户对象
Raises:
InvalidCredentialsError: 凭证无效
UserDisabledError: 用户被禁用
"""
# 查找用户(支持用户名或邮箱登录)
user = await self.user_repo.get_by_username_or_email(username)
if not user:
# 防止时序攻击:即使用户不存在也进行密码验证
verify_password(password, "$argon2id$v=19$m=65536,t=3,p=4$dummy$dummy")
raise InvalidCredentialsError()
# 验证密码
if not verify_password(password, user.hashed_password):
raise InvalidCredentialsError()
# 检查用户状态
if not user.is_active:
raise UserDisabledError()
# 检查是否需要重新哈希密码(参数升级)
if password_needs_rehash(user.hashed_password):
await self.user_repo.update(
user,
hashed_password=hash_password(password),
)
await self.user_repo.commit()
return user
async def login(
self,
username: str,
password: str,
) -> tuple[User, TokenResponse]:
"""
用户登录
Args:
username: 用户名或邮箱
password: 密码
Returns:
(用户对象, 令牌响应)
"""
user = await self.authenticate(username, password)
# 更新最后登录时间
await self.user_service.update_last_login(user)
# 生成令牌
tokens = self._create_tokens(user)
return user, tokens
def _create_tokens(self, user: User) -> TokenResponse:
"""
为用户创建访问令牌和刷新令牌
Args:
user: 用户对象
Returns:
令牌响应
"""
access_token = create_access_token(
subject=user.id,
extra_claims={
"username": user.username,
"is_superuser": user.is_superuser,
},
)
refresh_token = create_refresh_token(subject=user.id)
return TokenResponse(
access_token=access_token,
refresh_token=refresh_token,
token_type="Bearer",
expires_in=settings.access_token_expire_minutes * 60,
)
async def refresh_tokens(self, refresh_token: str) -> TokenResponse:
"""
刷新访问令牌
Args:
refresh_token: 刷新令牌
Returns:
新的令牌响应
Raises:
TokenError: 令牌无效
TokenExpiredError: 令牌已过期
"""
try:
payload = decode_token(refresh_token)
except jwt.ExpiredSignatureError:
raise TokenExpiredError()
except jwt.InvalidTokenError:
raise TokenError()
# 验证令牌类型
if payload.get("type") != "refresh":
raise TokenError("无效的令牌类型")
# 获取用户
user_id = payload.get("sub")
if not user_id:
raise TokenError()
user = await self.user_repo.get_by_id(user_id)
if not user:
raise UserNotFoundError(user_id)
if not user.is_active:
raise UserDisabledError()
# 生成新令牌
return self._create_tokens(user)
async def change_password(
self,
user_id: str,
password_data: PasswordChangeRequest,
) -> None:
"""
修改用户密码
Args:
user_id: 用户 ID
password_data: 密码修改数据
Raises:
UserNotFoundError: 用户不存在
InvalidCredentialsError: 当前密码错误
PasswordValidationError: 新密码不符合要求
"""
user = await self.user_repo.get_by_id(user_id)
if not user:
raise UserNotFoundError(user_id)
# 验证当前密码
if not verify_password(password_data.current_password, user.hashed_password):
raise InvalidCredentialsError("当前密码错误")
# 验证新密码强度
self._validate_password_strength(password_data.new_password)
# 更新密码
await self.user_repo.update(
user,
hashed_password=hash_password(password_data.new_password),
)
await self.user_repo.commit()
def _validate_password_strength(self, password: str) -> None:
"""
验证密码强度
Args:
password: 密码
Raises:
PasswordValidationError: 密码不符合要求
"""
import re
errors: list[str] = []
if len(password) < settings.password_min_length:
errors.append(f"密码长度不能少于 {settings.password_min_length}")
if len(password) > settings.password_max_length:
errors.append(f"密码长度不能超过 {settings.password_max_length}")
if settings.password_require_uppercase and not re.search(r"[A-Z]", password):
errors.append("至少包含一个大写字母")
if settings.password_require_lowercase and not re.search(r"[a-z]", password):
errors.append("至少包含一个小写字母")
if settings.password_require_digit and not re.search(r"\d", password):
errors.append("至少包含一个数字")
if settings.password_require_special and not re.search(r"[!@#$%^&*(),.?\":{}|<>]", password):
errors.append("至少包含一个特殊字符")
if errors:
raise PasswordValidationError("".join(errors))
async def get_current_user(self, token: str) -> User:
"""
从令牌获取当前用户
Args:
token: 访问令牌
Returns:
用户对象
Raises:
TokenError: 令牌无效
TokenExpiredError: 令牌已过期
UserNotFoundError: 用户不存在
UserDisabledError: 用户被禁用
"""
try:
payload = decode_token(token)
except jwt.ExpiredSignatureError:
raise TokenExpiredError()
except jwt.InvalidTokenError:
raise TokenError()
# 验证令牌类型
if payload.get("type") != "access":
raise TokenError("无效的令牌类型")
user_id = payload.get("sub")
if not user_id:
raise TokenError()
user = await self.user_repo.get_by_id(user_id)
if not user:
raise UserNotFoundError(user_id)
if not user.is_active:
raise UserDisabledError()
return user

934
app/services/balance.py Normal file
View File

@@ -0,0 +1,934 @@
"""
余额服务
处理余额相关的业务逻辑。
设计说明:
- 所有金额操作使用整数单位units避免浮点精度问题
- 扣款操作使用行级锁(悲观锁)确保原子性
- 充值操作使用乐观锁,配合重试机制
- 每笔操作都记录交易流水
预扣款流程(内部方法,用于耗时付费操作):
1. pre_authorize() - 预扣款冻结金额快速释放锁返回交易ID
2. 执行耗时的付费操作使用交易ID追踪
3. confirm() 或 cancel() - 根据操作结果确认或取消
推荐使用上下文管理器 deduction_context() 自动处理确认/取消。
"""
import logging
from contextlib import asynccontextmanager
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, AsyncIterator, Callable, Awaitable, TypeVar
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.exceptions import (
AppException,
ResourceNotFoundError,
ValidationError,
)
from app.models.balance import (
UserBalance,
BalanceTransaction,
TransactionType,
TransactionStatus,
)
from app.repositories.balance import BalanceRepository, TransactionRepository
logger = logging.getLogger(__name__)
T = TypeVar("T")
class InsufficientBalanceError(AppException):
"""余额不足"""
def __init__(self, required: int, available: int):
super().__init__(
f"余额不足,需要 {required / 1000:.2f},当前可用 {available / 1000:.2f}",
"INSUFFICIENT_BALANCE",
{"required_units": required, "available_units": available},
)
class DuplicateTransactionError(AppException):
"""重复交易"""
def __init__(self, idempotency_key: str):
super().__init__(
"该交易已处理",
"DUPLICATE_TRANSACTION",
{"idempotency_key": idempotency_key},
)
class ConcurrencyError(AppException):
"""并发冲突"""
def __init__(self):
super().__init__(
"操作冲突,请重试",
"CONCURRENCY_ERROR",
)
class TransactionNotFoundError(AppException):
"""交易不存在"""
def __init__(self, transaction_id: str):
super().__init__(
"交易记录不存在",
"TRANSACTION_NOT_FOUND",
{"transaction_id": transaction_id},
)
class TransactionStateError(AppException):
"""交易状态错误"""
def __init__(self, transaction_id: str, current_status: str, expected_status: str = "pending"):
super().__init__(
f"交易状态无效:当前 {current_status},预期 {expected_status}",
"TRANSACTION_STATE_ERROR",
{
"transaction_id": transaction_id,
"current_status": current_status,
"expected_status": expected_status,
},
)
@dataclass
class PreAuthResult:
"""
预授权结果
包含交易ID和相关信息用于后续确认或取消操作。
"""
transaction_id: str
"""交易ID用于后续 confirm/cancel 操作"""
user_id: str
"""用户ID"""
amount_units: int
"""预扣款金额(单位额度)"""
frozen_at: datetime
"""冻结时间"""
@property
def amount_display(self) -> str:
"""显示金额2位小数"""
return f"{self.amount_units / 1000:.2f}"
@dataclass
class DeductionResult:
"""
扣款结果
包含扣款操作的完整信息。
"""
transaction_id: str
"""交易ID"""
status: TransactionStatus
"""交易状态"""
amount_units: int
"""实际扣款金额(单位额度)"""
balance_before: int
"""扣款前余额"""
balance_after: int
"""扣款后余额"""
@property
def success(self) -> bool:
"""是否扣款成功"""
return self.status == TransactionStatus.COMPLETED
@property
def amount_display(self) -> str:
"""显示金额"""
return f"{abs(self.amount_units) / 1000:.2f}"
@property
def balance_before_display(self) -> str:
"""显示扣款前余额"""
return f"{self.balance_before / 1000:.2f}"
@property
def balance_after_display(self) -> str:
"""显示扣款后余额"""
return f"{self.balance_after / 1000:.2f}"
class BalanceService:
"""余额服务"""
# 乐观锁最大重试次数
MAX_RETRIES = 3
def __init__(self, session: AsyncSession):
"""
初始化余额服务
Args:
session: 数据库会话
"""
self.session = session
self.balance_repo = BalanceRepository(session)
self.transaction_repo = TransactionRepository(session)
# ============================================================
# 余额查询
# ============================================================
async def get_balance(self, user_id: str) -> UserBalance:
"""
获取用户余额
如果用户没有余额账户,自动创建一个。
Args:
user_id: 用户 ID
Returns:
余额账户
"""
balance = await self.balance_repo.get_or_create(user_id)
await self.balance_repo.commit()
return balance
async def get_balance_detail(self, user_id: str) -> dict[str, Any]:
"""
获取用户余额详情
Args:
user_id: 用户 ID
Returns:
余额详情字典
"""
balance = await self.get_balance(user_id)
return {
"user_id": balance.user_id,
"balance_units": balance.balance,
"frozen_units": balance.frozen_balance,
"available_units": balance.available_balance,
"total_recharged_units": balance.total_recharged,
"total_consumed_units": balance.total_consumed,
}
async def get_transactions(
self,
user_id: str,
*,
offset: int = 0,
limit: int = 20,
transaction_type: TransactionType | None = None,
) -> tuple[list[BalanceTransaction], int]:
"""
获取用户交易记录
Args:
user_id: 用户 ID
offset: 偏移量
limit: 限制数量
transaction_type: 交易类型过滤
Returns:
(交易记录列表, 总数)
"""
transactions = await self.transaction_repo.get_by_user_id(
user_id,
offset=offset,
limit=limit,
transaction_type=transaction_type,
)
total = await self.transaction_repo.count_by_user_id(
user_id,
transaction_type=transaction_type,
)
return transactions, total
# ============================================================
# 扣款操作(使用行级锁 - 悲观锁)
# ============================================================
async def deduct(
self,
user_id: str,
amount_units: int,
*,
reference_type: str | None = None,
reference_id: str | None = None,
description: str | None = None,
idempotency_key: str | None = None,
) -> BalanceTransaction:
"""
扣款
使用行级锁确保原子性,防止并发扣款导致余额变负。
Args:
user_id: 用户 ID
amount_units: 扣款金额(单位额度,正数)
reference_type: 关联业务类型
reference_id: 关联业务 ID
description: 交易描述
idempotency_key: 幂等键
Returns:
交易记录
Raises:
InsufficientBalanceError: 余额不足
DuplicateTransactionError: 重复交易
"""
if amount_units <= 0:
raise ValidationError("扣款金额必须大于 0")
# 检查幂等性
if idempotency_key:
existing = await self.transaction_repo.get_by_idempotency_key(
idempotency_key
)
if existing:
raise DuplicateTransactionError(idempotency_key)
# 获取余额账户并加锁
balance = await self.balance_repo.get_or_create_for_update(user_id)
# 检查可用余额
if balance.available_balance < amount_units:
raise InsufficientBalanceError(amount_units, balance.available_balance)
# 记录扣款前余额
balance_before = balance.balance
# 执行扣款
balance.balance -= amount_units
balance.total_consumed += amount_units
balance.version += 1
# 创建交易记录
transaction = await self.transaction_repo.create(
user_id=user_id,
balance_account_id=balance.id,
transaction_type=TransactionType.DEDUCTION,
status=TransactionStatus.COMPLETED,
amount=-amount_units, # 负数表示支出
balance_before=balance_before,
balance_after=balance.balance,
reference_type=reference_type,
reference_id=reference_id,
description=description,
idempotency_key=idempotency_key,
)
await self.balance_repo.commit()
logger.info(
f"用户 {user_id} 扣款成功: {amount_units} 单位, "
f"余额 {balance_before} -> {balance.balance}"
)
return transaction
# ============================================================
# 预扣款流程(内部方法,用于耗时付费操作)
# ============================================================
async def pre_authorize(
self,
user_id: str,
amount_units: int,
*,
reference_type: str | None = None,
reference_id: str | None = None,
description: str | None = None,
) -> PreAuthResult:
"""
预授权扣款(内部方法)
冻结指定金额快速释放数据库锁返回交易ID供后续操作使用。
此方法设计用于耗时的付费操作场景。
使用流程:
1. 调用 pre_authorize() 获取 PreAuthResult
2. 执行可能失败的耗时操作
3. 根据操作结果调用 confirm() 或 cancel()
推荐使用 deduction_context() 上下文管理器自动处理。
Args:
user_id: 用户 ID
amount_units: 扣款金额(单位额度,正数)
reference_type: 关联业务类型(如 api_call, service
reference_id: 关联业务 ID
description: 交易描述
Returns:
PreAuthResult: 预授权结果包含交易ID
Raises:
InsufficientBalanceError: 余额不足
ValidationError: 参数无效
"""
if amount_units <= 0:
raise ValidationError("预扣款金额必须大于 0")
# 获取余额账户并加锁(短暂持有)
balance = await self.balance_repo.get_or_create_for_update(user_id)
# 检查可用余额
if balance.available_balance < amount_units:
raise InsufficientBalanceError(amount_units, balance.available_balance)
now = datetime.now(timezone.utc)
# 执行冻结
balance.frozen_balance += amount_units
balance.version += 1
# 创建待处理交易记录
transaction = await self.transaction_repo.create(
user_id=user_id,
balance_account_id=balance.id,
transaction_type=TransactionType.DEDUCTION,
status=TransactionStatus.PENDING,
amount=-amount_units,
balance_before=balance.balance,
balance_after=balance.balance, # 尚未实际扣款
reference_type=reference_type,
reference_id=reference_id,
description=description,
remark=f"预授权冻结: {amount_units} 单位",
)
# 快速提交释放锁
await self.balance_repo.commit()
logger.info(
f"用户 {user_id} 预授权成功: {amount_units} 单位, "
f"交易ID: {transaction.id}"
)
return PreAuthResult(
transaction_id=transaction.id,
user_id=user_id,
amount_units=amount_units,
frozen_at=now,
)
async def confirm(
self,
transaction_id: str,
*,
actual_amount_units: int | None = None,
) -> DeductionResult:
"""
确认预授权扣款(内部方法)
将预冻结的金额实际扣除。支持部分扣款。
Args:
transaction_id: 预授权交易 ID
actual_amount_units: 实际扣款金额(可选,用于部分扣款,默认全额)
Returns:
DeductionResult: 扣款结果
Raises:
TransactionNotFoundError: 交易不存在
TransactionStateError: 交易状态不是 PENDING
ValidationError: 参数无效
"""
transaction = await self.transaction_repo.get_by_id(transaction_id)
if not transaction:
raise TransactionNotFoundError(transaction_id)
if transaction.status != TransactionStatus.PENDING:
raise TransactionStateError(
transaction_id,
transaction.status.value,
)
# 获取余额账户并加锁
balance = await self.balance_repo.get_by_user_id_for_update(transaction.user_id)
if not balance:
raise ResourceNotFoundError("余额账户不存在")
frozen_amount = abs(transaction.amount)
# 确定实际扣款金额
if actual_amount_units is not None:
if actual_amount_units <= 0:
raise ValidationError("实际扣款金额必须大于 0")
if actual_amount_units > frozen_amount:
raise ValidationError(
f"实际扣款金额 ({actual_amount_units}) 不能超过预授权金额 ({frozen_amount})"
)
deduct_amount = actual_amount_units
else:
deduct_amount = frozen_amount
# 检查冻结金额
if balance.frozen_balance < frozen_amount:
raise ValidationError("冻结金额不足,可能已被其他操作修改")
balance_before = balance.balance
# 执行扣款:解冻全部,扣除实际金额
balance.frozen_balance -= frozen_amount
balance.balance -= deduct_amount
balance.total_consumed += deduct_amount
balance.version += 1
# 更新交易记录
transaction.status = TransactionStatus.COMPLETED
transaction.amount = -deduct_amount # 更新为实际扣款金额
transaction.balance_after = balance.balance
await self.balance_repo.commit()
logger.info(
f"用户 {transaction.user_id} 确认扣款: {deduct_amount} 单位, "
f"余额 {balance_before} -> {balance.balance}"
)
return DeductionResult(
transaction_id=transaction.id,
status=TransactionStatus.COMPLETED,
amount_units=deduct_amount,
balance_before=balance_before,
balance_after=balance.balance,
)
async def cancel(
self,
transaction_id: str,
*,
reason: str | None = None,
) -> DeductionResult:
"""
取消预授权扣款(内部方法)
解冻预授权的金额,退回用户可用余额。
Args:
transaction_id: 预授权交易 ID
reason: 取消原因(可选,记录在日志中)
Returns:
DeductionResult: 取消结果
Raises:
TransactionNotFoundError: 交易不存在
TransactionStateError: 交易状态不是 PENDING
"""
transaction = await self.transaction_repo.get_by_id(transaction_id)
if not transaction:
raise TransactionNotFoundError(transaction_id)
if transaction.status != TransactionStatus.PENDING:
raise TransactionStateError(
transaction_id,
transaction.status.value,
)
# 获取余额账户并加锁
balance = await self.balance_repo.get_by_user_id_for_update(transaction.user_id)
if not balance:
raise ResourceNotFoundError("余额账户不存在")
frozen_amount = abs(transaction.amount)
# 解冻
balance.frozen_balance -= frozen_amount
balance.version += 1
# 更新交易记录
transaction.status = TransactionStatus.CANCELLED
if reason:
transaction.remark = f"{transaction.remark or ''}; 取消原因: {reason}"
await self.balance_repo.commit()
logger.info(
f"用户 {transaction.user_id} 取消预授权: {frozen_amount} 单位"
+ (f", 原因: {reason}" if reason else "")
)
return DeductionResult(
transaction_id=transaction.id,
status=TransactionStatus.CANCELLED,
amount_units=0, # 实际未扣款
balance_before=balance.balance,
balance_after=balance.balance,
)
@asynccontextmanager
async def deduction_context(
self,
user_id: str,
amount_units: int,
*,
reference_type: str | None = None,
reference_id: str | None = None,
description: str | None = None,
auto_cancel_on_error: bool = True,
) -> AsyncIterator[PreAuthResult]:
"""
扣款上下文管理器(推荐使用)
提供简便的预扣款流程,自动处理确认和取消。
异常时自动取消预授权,退回冻结金额。
用法示例:
```python
async with balance_service.deduction_context(
user_id,
1000, # 扣款金额(单位额度)
reference_type="api_call",
description="API调用费用",
) as pre_auth:
# pre_auth.transaction_id 可用于追踪
# 执行可能失败的耗时操作
result = await call_external_api()
if not result.success:
raise Exception("API 调用失败")
# 成功退出时自动确认扣款
# 异常退出时自动取消预授权(如果 auto_cancel_on_error=True
```
Args:
user_id: 用户 ID
amount_units: 扣款金额(单位额度)
reference_type: 关联业务类型
reference_id: 关联业务 ID
description: 交易描述
auto_cancel_on_error: 异常时是否自动取消(默认 True
Yields:
PreAuthResult: 预授权结果包含交易ID
Raises:
InsufficientBalanceError: 余额不足
"""
# 第一阶段:预授权
pre_auth = await self.pre_authorize(
user_id,
amount_units,
reference_type=reference_type,
reference_id=reference_id,
description=description,
)
try:
yield pre_auth
# 正常退出:确认扣款
await self.confirm(pre_auth.transaction_id)
except Exception as e:
# 异常退出:取消预授权
if auto_cancel_on_error:
try:
await self.cancel(
pre_auth.transaction_id,
reason=f"操作失败: {str(e)[:200]}",
)
except Exception as cancel_error:
logger.error(
f"取消预授权失败: {pre_auth.transaction_id}, "
f"错误: {cancel_error}"
)
raise
async def execute_with_deduction(
self,
user_id: str,
amount_units: int,
operation: Callable[[PreAuthResult], Awaitable[T]],
*,
reference_type: str | None = None,
reference_id: str | None = None,
description: str | None = None,
) -> tuple[DeductionResult, T]:
"""
执行带扣款的操作(函数式接口)
预扣款后执行指定操作,根据操作结果自动确认或取消。
用法示例:
```python
async def call_api(pre_auth: PreAuthResult):
return await external_api.call(
transaction_id=pre_auth.transaction_id,
amount=pre_auth.amount_display,
)
deduction_result, api_result = await balance_service.execute_with_deduction(
user_id,
1000,
call_api,
reference_type="api_call",
)
```
Args:
user_id: 用户 ID
amount_units: 扣款金额(单位额度)
operation: 要执行的异步操作,接收 PreAuthResult 参数
reference_type: 关联业务类型
reference_id: 关联业务 ID
description: 交易描述
Returns:
(DeductionResult, operation返回值): 扣款结果和操作结果
Raises:
InsufficientBalanceError: 余额不足
Exception: 操作抛出的异常(预授权会自动取消)
"""
pre_auth = await self.pre_authorize(
user_id,
amount_units,
reference_type=reference_type,
reference_id=reference_id,
description=description,
)
try:
# 执行操作
result = await operation(pre_auth)
# 成功:确认扣款
deduction_result = await self.confirm(pre_auth.transaction_id)
return deduction_result, result
except Exception as e:
# 失败:取消预授权
try:
await self.cancel(
pre_auth.transaction_id,
reason=f"操作失败: {str(e)[:200]}",
)
except Exception as cancel_error:
logger.error(
f"取消预授权失败: {pre_auth.transaction_id}, "
f"错误: {cancel_error}"
)
raise
# ============================================================
# 兼容方法(保留旧接口)
# ============================================================
async def deduct_with_freeze(
self,
user_id: str,
amount_units: int,
*,
reference_type: str | None = None,
reference_id: str | None = None,
description: str | None = None,
) -> str:
"""
冻结并预扣款(兼容方法,推荐使用 pre_authorize
Returns:
交易ID
"""
result = await self.pre_authorize(
user_id,
amount_units,
reference_type=reference_type,
reference_id=reference_id,
description=description,
)
return result.transaction_id
async def confirm_frozen_deduction(self, transaction_id: str) -> BalanceTransaction:
"""
确认冻结扣款(兼容方法,推荐使用 confirm
"""
await self.confirm(transaction_id)
transaction = await self.transaction_repo.get_by_id(transaction_id)
return transaction # type: ignore
async def cancel_frozen_deduction(self, transaction_id: str) -> BalanceTransaction:
"""
取消冻结扣款(兼容方法,推荐使用 cancel
"""
await self.cancel(transaction_id)
transaction = await self.transaction_repo.get_by_id(transaction_id)
return transaction # type: ignore
# ============================================================
# 充值操作(使用乐观锁 + 重试)
# ============================================================
async def recharge(
self,
user_id: str,
amount_units: int,
*,
reference_type: str | None = None,
reference_id: str | None = None,
description: str | None = None,
idempotency_key: str | None = None,
) -> BalanceTransaction:
"""
充值
使用乐观锁,配合重试机制处理并发冲突。
Args:
user_id: 用户 ID
amount_units: 充值金额(单位额度,正数)
reference_type: 关联业务类型
reference_id: 关联业务 ID
description: 交易描述
idempotency_key: 幂等键
Returns:
交易记录
Raises:
DuplicateTransactionError: 重复交易
ConcurrencyError: 并发冲突(重试失败)
"""
if amount_units <= 0:
raise ValidationError("充值金额必须大于 0")
# 检查幂等性
if idempotency_key:
existing = await self.transaction_repo.get_by_idempotency_key(
idempotency_key
)
if existing:
raise DuplicateTransactionError(idempotency_key)
# 乐观锁重试
for attempt in range(self.MAX_RETRIES):
balance = await self.balance_repo.get_or_create(user_id)
balance_before = balance.balance
# 尝试更新余额
success = await self.balance_repo.update_balance_optimistic(
balance,
amount_units,
is_recharge=True,
)
if success:
# 创建交易记录
transaction = await self.transaction_repo.create(
user_id=user_id,
balance_account_id=balance.id,
transaction_type=TransactionType.RECHARGE,
status=TransactionStatus.COMPLETED,
amount=amount_units, # 正数表示收入
balance_before=balance_before,
balance_after=balance.balance,
reference_type=reference_type,
reference_id=reference_id,
description=description,
idempotency_key=idempotency_key,
)
await self.balance_repo.commit()
logger.info(
f"用户 {user_id} 充值成功: {amount_units} 单位, "
f"余额 {balance_before} -> {balance.balance}"
)
return transaction
# 冲突,重试
logger.warning(
f"用户 {user_id} 充值冲突,重试 {attempt + 1}/{self.MAX_RETRIES}"
)
await self.balance_repo.rollback()
# 重试失败
raise ConcurrencyError()
# ============================================================
# 管理员操作
# ============================================================
async def admin_adjust(
self,
user_id: str,
amount_units: int,
*,
operator_id: str,
reason: str,
) -> BalanceTransaction:
"""
管理员调整余额
Args:
user_id: 目标用户 ID
amount_units: 调整金额(正数增加,负数减少)
operator_id: 操作人 ID
reason: 调整原因
Returns:
交易记录
Raises:
InsufficientBalanceError: 减少金额时余额不足
"""
if amount_units == 0:
raise ValidationError("调整金额不能为 0")
# 获取余额账户并加锁
balance = await self.balance_repo.get_or_create_for_update(user_id)
# 减少时检查余额
if amount_units < 0 and balance.available_balance < abs(amount_units):
raise InsufficientBalanceError(
abs(amount_units), balance.available_balance
)
balance_before = balance.balance
# 执行调整
balance.balance += amount_units
if amount_units > 0:
balance.total_recharged += amount_units
balance.version += 1
# 创建交易记录
transaction = await self.transaction_repo.create(
user_id=user_id,
balance_account_id=balance.id,
transaction_type=TransactionType.ADJUSTMENT,
status=TransactionStatus.COMPLETED,
amount=amount_units,
balance_before=balance_before,
balance_after=balance.balance,
description=reason,
operator_id=operator_id,
remark=f"管理员调整: {reason}",
)
await self.balance_repo.commit()
logger.info(
f"管理员 {operator_id} 调整用户 {user_id} 余额: {amount_units} 单位, "
f"原因: {reason}"
)
return transaction

395
app/services/oauth2.py Normal file
View File

@@ -0,0 +1,395 @@
"""
OAuth2 认证服务
处理 OAuth2 认证流程,支持主备端点自动切换。
当首选端点不可达时,自动回退到备用端点。
"""
import logging
import secrets
from urllib.parse import urlencode
import httpx
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.exceptions import (
AuthenticationError,
ResourceConflictError,
)
from app.core.security import create_access_token, create_refresh_token
from app.models.user import User
from app.repositories.user import UserRepository
from app.schemas.auth import TokenResponse
from app.schemas.oauth2 import OAuth2TokenData, OAuth2UserInfo
logger = logging.getLogger(__name__)
# OAuth2 提供商标识
OAUTH_PROVIDER_LINUXDO = "linuxdo"
class OAuth2EndpointError(AuthenticationError):
"""OAuth2 端点错误"""
def __init__(self, message: str = "OAuth2 服务不可用"):
super().__init__(message, "OAUTH2_ENDPOINT_ERROR")
class OAuth2StateError(AuthenticationError):
"""OAuth2 状态验证错误"""
def __init__(self, message: str = "无效的状态码"):
super().__init__(message, "OAUTH2_STATE_ERROR")
class OAuth2Service:
"""
OAuth2 认证服务
特性:
- 支持主备端点自动切换
- 首选端点请求失败时自动回退到备用端点
- 状态码验证防止 CSRF 攻击
"""
# 存储状态码(生产环境应使用 Redis
_state_store: dict[str, bool] = {}
def __init__(self, session: AsyncSession):
"""
初始化 OAuth2 服务
Args:
session: 数据库会话
"""
self.session = session
self.user_repo = UserRepository(session)
# 端点配置
self._endpoints = {
"authorize": {
"primary": settings.oauth2_authorize_endpoint,
"reserve": settings.oauth2_authorize_endpoint_reserve,
},
"token": {
"primary": settings.oauth2_token_endpoint,
"reserve": settings.oauth2_token_endpoint_reserve,
},
"userinfo": {
"primary": settings.oauth2_user_info_endpoint,
"reserve": settings.oauth2_user_info_endpoint_reserve,
},
}
self._timeout = settings.oauth2_request_timeout
def generate_authorize_url(self, redirect_uri: str) -> tuple[str, str]:
"""
生成 OAuth2 授权 URL
Args:
redirect_uri: 回调 URL
Returns:
(授权 URL, 状态码)
"""
state = secrets.token_urlsafe(32)
self._state_store[state] = True # 存储状态码
params = {
"client_id": settings.oauth2_client_id,
"redirect_uri": redirect_uri,
"response_type": "code",
"state": state,
"scope": "read", # 根据实际需要调整
}
# 使用首选授权端点
authorize_url = f"{self._endpoints['authorize']['primary']}?{urlencode(params)}"
return authorize_url, state
def validate_state(self, state: str) -> bool:
"""
验证状态码(防 CSRF
Args:
state: 状态码
Returns:
是否有效
"""
if state in self._state_store:
del self._state_store[state] # 使用后立即删除
return True
return False
async def _request_with_fallback(
self,
endpoint_type: str,
method: str,
**kwargs,
) -> httpx.Response:
"""
带回退的 HTTP 请求
首先尝试首选端点,失败后自动切换到备用端点。
Args:
endpoint_type: 端点类型token/userinfo
method: HTTP 方法
**kwargs: 请求参数
Returns:
响应对象
Raises:
OAuth2EndpointError: 所有端点都不可用
"""
endpoints = self._endpoints[endpoint_type]
last_error: Exception | None = None
for endpoint_name, url in [("primary", endpoints["primary"]), ("reserve", endpoints["reserve"])]:
try:
logger.debug(f"尝试 OAuth2 {endpoint_type} 端点 ({endpoint_name}): {url}")
async with httpx.AsyncClient(timeout=self._timeout) as client:
if method.upper() == "POST":
response = await client.post(url, **kwargs)
else:
response = await client.get(url, **kwargs)
# 检查 HTTP 状态
if response.status_code >= 500:
logger.warning(
f"OAuth2 {endpoint_type} 端点 ({endpoint_name}) "
f"返回服务器错误: {response.status_code}"
)
continue
logger.info(f"OAuth2 {endpoint_type} 请求成功 ({endpoint_name})")
return response
except httpx.TimeoutException as e:
logger.warning(f"OAuth2 {endpoint_type} 端点 ({endpoint_name}) 超时: {e}")
last_error = e
except httpx.ConnectError as e:
logger.warning(f"OAuth2 {endpoint_type} 端点 ({endpoint_name}) 连接失败: {e}")
last_error = e
except Exception as e:
logger.error(f"OAuth2 {endpoint_type} 端点 ({endpoint_name}) 请求异常: {e}")
last_error = e
# 所有端点都失败
error_msg = f"OAuth2 {endpoint_type} 服务不可用"
if last_error:
error_msg += f": {last_error}"
raise OAuth2EndpointError(error_msg)
async def exchange_code_for_token(
self,
code: str,
redirect_uri: str,
) -> OAuth2TokenData:
"""
用授权码换取访问令牌
Args:
code: 授权码
redirect_uri: 回调 URL必须与授权时一致
Returns:
OAuth2 令牌数据
Raises:
OAuth2EndpointError: 端点不可用
AuthenticationError: 换取令牌失败
"""
data = {
"grant_type": "authorization_code",
"client_id": settings.oauth2_client_id,
"client_secret": settings.oauth2_client_secret,
"code": code,
"redirect_uri": redirect_uri,
}
response = await self._request_with_fallback(
"token",
"POST",
data=data,
headers={"Accept": "application/json"},
)
if response.status_code != 200:
logger.error(f"OAuth2 token 响应错误: {response.status_code} - {response.text}")
raise AuthenticationError(
f"获取访问令牌失败: {response.status_code}",
"OAUTH2_TOKEN_ERROR",
)
token_data = response.json()
return OAuth2TokenData(**token_data)
async def get_user_info(self, access_token: str) -> OAuth2UserInfo:
"""
获取 OAuth2 用户信息
Args:
access_token: OAuth2 访问令牌
Returns:
用户信息
Raises:
OAuth2EndpointError: 端点不可用
AuthenticationError: 获取用户信息失败
"""
response = await self._request_with_fallback(
"userinfo",
"GET",
headers={
"Authorization": f"Bearer {access_token}",
"Accept": "application/json",
},
)
if response.status_code != 200:
logger.error(f"OAuth2 userinfo 响应错误: {response.status_code} - {response.text}")
raise AuthenticationError(
f"获取用户信息失败: {response.status_code}",
"OAUTH2_USERINFO_ERROR",
)
user_data = response.json()
return OAuth2UserInfo(**user_data)
async def authenticate(
self,
code: str,
state: str,
redirect_uri: str,
) -> tuple[User, TokenResponse, bool]:
"""
完整的 OAuth2 认证流程
1. 验证状态码
2. 用授权码换取令牌
3. 获取用户信息
4. 创建或更新用户
5. 生成 JWT 令牌
Args:
code: 授权码
state: 状态码
redirect_uri: 回调 URL
Returns:
(用户对象, JWT 令牌响应, 是否新用户)
Raises:
OAuth2StateError: 状态码无效
OAuth2EndpointError: OAuth2 服务不可用
AuthenticationError: 认证失败
"""
# 1. 验证状态码
if not self.validate_state(state):
raise OAuth2StateError()
# 2. 换取令牌
oauth_token = await self.exchange_code_for_token(code, redirect_uri)
# 3. 获取用户信息
oauth_user = await self.get_user_info(oauth_token.access_token)
# 4. 查找或创建用户
user, is_new_user = await self._get_or_create_user(oauth_user)
# 5. 生成 JWT 令牌
tokens = self._create_tokens(user)
return user, tokens, is_new_user
async def _get_or_create_user(
self,
oauth_user: OAuth2UserInfo,
) -> tuple[User, bool]:
"""
根据 OAuth2 用户信息获取或创建本地用户
Args:
oauth_user: OAuth2 用户信息
Returns:
(用户对象, 是否新创建)
"""
oauth_user_id = str(oauth_user.id)
# 先通过 OAuth ID 查找
user = await self.user_repo.get_by_oauth(
provider=OAUTH_PROVIDER_LINUXDO,
oauth_user_id=oauth_user_id,
)
if user:
# 更新用户信息(头像等可能变化)
await self.user_repo.update(
user,
nickname=oauth_user.name or oauth_user.username,
avatar_url=oauth_user.avatar_url,
)
await self.user_repo.commit()
return user, False
# 检查用户名是否已存在
username = oauth_user.username.lower()
existing_user = await self.user_repo.get_by_username(username)
if existing_user:
# 用户名冲突,添加后缀
username = f"{username}_{oauth_user_id[:8]}"
# 创建新用户
user = await self.user_repo.create(
username=username,
email=oauth_user.email,
nickname=oauth_user.name or oauth_user.username,
avatar_url=oauth_user.avatar_url,
oauth_provider=OAUTH_PROVIDER_LINUXDO,
oauth_user_id=oauth_user_id,
hashed_password=None, # OAuth2 用户无密码
is_active=oauth_user.active,
)
await self.user_repo.commit()
logger.info(f"创建 OAuth2 用户: {user.username} (provider={OAUTH_PROVIDER_LINUXDO})")
return user, True
def _create_tokens(self, user: User) -> TokenResponse:
"""
为用户创建 JWT 令牌
Args:
user: 用户对象
Returns:
令牌响应
"""
access_token = create_access_token(
subject=user.id,
extra_claims={
"username": user.username,
"is_superuser": user.is_superuser,
"oauth_provider": user.oauth_provider,
},
)
refresh_token = create_refresh_token(subject=user.id)
return TokenResponse(
access_token=access_token,
refresh_token=refresh_token,
token_type="Bearer",
expires_in=settings.access_token_expire_minutes * 60,
)

570
app/services/redeem_code.py Normal file
View File

@@ -0,0 +1,570 @@
"""
兑换码服务
处理兑换码相关的业务逻辑。
设计说明:
- 兑换操作使用行级锁确保原子性
- 支持批量生成和导入导出
- 记录完整的使用日志
"""
import logging
from datetime import datetime, timezone
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.exceptions import (
AppException,
ResourceNotFoundError,
ValidationError,
)
from app.models.redeem_code import (
RedeemCode,
RedeemCodeBatch,
RedeemCodeUsageLog,
RedeemCodeStatus,
generate_redeem_code,
)
from app.models.balance import TransactionType
from app.repositories.redeem_code import (
RedeemCodeRepository,
RedeemCodeBatchRepository,
RedeemCodeUsageLogRepository,
)
from app.services.balance import BalanceService
logger = logging.getLogger(__name__)
class RedeemCodeNotFoundError(AppException):
"""兑换码不存在"""
def __init__(self, code: str):
super().__init__(
"兑换码不存在",
"REDEEM_CODE_NOT_FOUND",
{"code": code},
)
class RedeemCodeInvalidError(AppException):
"""兑换码无效"""
def __init__(self, code: str, reason: str):
super().__init__(
f"兑换码无效: {reason}",
"REDEEM_CODE_INVALID",
{"code": code, "reason": reason},
)
class RedeemCodeExpiredError(AppException):
"""兑换码已过期"""
def __init__(self, code: str):
super().__init__(
"兑换码已过期",
"REDEEM_CODE_EXPIRED",
{"code": code},
)
class RedeemCodeUsedError(AppException):
"""兑换码已使用"""
def __init__(self, code: str):
super().__init__(
"兑换码已使用",
"REDEEM_CODE_USED",
{"code": code},
)
class RedeemCodeDisabledError(AppException):
"""兑换码已禁用"""
def __init__(self, code: str):
super().__init__(
"兑换码已禁用",
"REDEEM_CODE_DISABLED",
{"code": code},
)
class RedeemCodeService:
"""兑换码服务"""
def __init__(self, session: AsyncSession):
"""
初始化兑换码服务
Args:
session: 数据库会话
"""
self.session = session
self.code_repo = RedeemCodeRepository(session)
self.batch_repo = RedeemCodeBatchRepository(session)
self.log_repo = RedeemCodeUsageLogRepository(session)
self.balance_service = BalanceService(session)
# ============================================================
# 用户兑换
# ============================================================
async def redeem(
self,
user_id: str,
code: str,
*,
ip_address: str | None = None,
user_agent: str | None = None,
) -> dict[str, Any]:
"""
用户兑换余额
使用行级锁确保原子性,防止并发兑换。
Args:
user_id: 用户 ID
code: 兑换码
ip_address: 客户端 IP
user_agent: User Agent
Returns:
兑换结果
Raises:
RedeemCodeNotFoundError: 兑换码不存在
RedeemCodeInvalidError: 兑换码无效
"""
# 标准化兑换码
normalized_code = code.strip().upper().replace(" ", "")
# 获取兑换码并加锁
redeem_code = await self.code_repo.get_by_code_for_update(normalized_code)
if not redeem_code:
raise RedeemCodeNotFoundError(normalized_code)
# 验证兑换码状态
self._validate_redeem_code(redeem_code)
# 获取用户当前余额
balance = await self.balance_service.get_balance(user_id)
balance_before = balance.balance
# 执行充值
transaction = await self.balance_service.recharge(
user_id,
redeem_code.face_value,
reference_type="redeem_code",
reference_id=redeem_code.id,
description=f"兑换码充值: {redeem_code.code}",
)
# 标记兑换码已使用
await self.code_repo.mark_as_used(redeem_code, user_id)
# 更新批次统计
if redeem_code.batch_id:
await self.batch_repo.increment_used_count(redeem_code.batch_id)
# 记录使用日志
await self.log_repo.create(
redeem_code_id=redeem_code.id,
user_id=user_id,
transaction_id=transaction.id,
code_snapshot=redeem_code.code,
face_value=redeem_code.face_value,
ip_address=ip_address,
user_agent=user_agent,
)
await self.code_repo.commit()
logger.info(
f"用户 {user_id} 兑换成功: {redeem_code.code}, "
f"面值 {redeem_code.face_value} 单位"
)
return {
"success": True,
"message": "兑换成功",
"face_value": f"{redeem_code.face_value / 1000:.2f}",
"balance_before": f"{balance_before / 1000:.2f}",
"balance_after": f"{transaction.balance_after / 1000:.2f}",
}
def _validate_redeem_code(self, code: RedeemCode) -> None:
"""验证兑换码有效性"""
if code.status == RedeemCodeStatus.DISABLED:
raise RedeemCodeDisabledError(code.code)
if code.status == RedeemCodeStatus.USED or code.used_count >= code.max_uses:
raise RedeemCodeUsedError(code.code)
if code.expires_at and code.expires_at < datetime.now(timezone.utc):
raise RedeemCodeExpiredError(code.code)
# ============================================================
# 管理员:批量生成
# ============================================================
async def create_batch(
self,
name: str,
face_value_units: int,
count: int,
*,
created_by: str,
description: str | None = None,
max_uses: int = 1,
expires_at: datetime | None = None,
) -> RedeemCodeBatch:
"""
创建兑换码批次
批量生成指定数量的兑换码。
Args:
name: 批次名称
face_value_units: 面值(单位额度)
count: 生成数量
created_by: 创建者 ID
description: 批次描述
max_uses: 每个兑换码最大使用次数
expires_at: 过期时间
Returns:
创建的批次
"""
if face_value_units <= 0:
raise ValidationError("面值必须大于 0")
if count <= 0 or count > 10000:
raise ValidationError("数量必须在 1-10000 之间")
# 创建批次
batch = await self.batch_repo.create(
name=name,
description=description,
face_value=face_value_units,
total_count=count,
created_by=created_by,
)
# 批量生成兑换码
codes_data = []
generated_codes = set()
while len(codes_data) < count:
new_code = generate_redeem_code()
if new_code not in generated_codes:
generated_codes.add(new_code)
codes_data.append({
"code": new_code,
"batch_id": batch.id,
"face_value": face_value_units,
"max_uses": max_uses,
"expires_at": expires_at,
"created_by": created_by,
})
await self.code_repo.bulk_create(codes_data)
await self.batch_repo.commit()
logger.info(
f"管理员 {created_by} 创建批次 '{name}': "
f"{count} 个兑换码, 面值 {face_value_units} 单位"
)
return batch
# ============================================================
# 管理员:导入
# ============================================================
async def import_codes(
self,
codes: list[dict[str, Any]],
*,
created_by: str,
batch_name: str | None = None,
) -> dict[str, Any]:
"""
导入兑换码
Args:
codes: 兑换码数据列表
created_by: 创建者 ID
batch_name: 批次名称(可选)
Returns:
导入结果
"""
batch_id = None
# 创建批次(如果指定)
if batch_name:
# 计算总面值用于批次记录
total_face_value = sum(c.get("face_value_units", 0) for c in codes)
batch = await self.batch_repo.create(
name=batch_name,
description="导入批次",
face_value=total_face_value // len(codes) if codes else 0,
total_count=len(codes),
created_by=created_by,
)
batch_id = batch.id
success_count = 0
failed_codes = []
for code_data in codes:
try:
# 检查兑换码是否已存在
existing = await self.code_repo.get_by_code(code_data["code"])
if existing:
failed_codes.append(code_data["code"])
continue
# 创建兑换码
await self.code_repo.create(
code=code_data["code"].strip().upper(),
batch_id=batch_id,
face_value=code_data["face_value_units"],
max_uses=code_data.get("max_uses", 1),
expires_at=code_data.get("expires_at"),
remark=code_data.get("remark"),
created_by=created_by,
)
success_count += 1
except Exception as e:
logger.warning(f"导入兑换码失败: {code_data.get('code')}, {e}")
failed_codes.append(code_data.get("code", "unknown"))
await self.code_repo.commit()
logger.info(
f"管理员 {created_by} 导入兑换码: "
f"成功 {success_count}, 失败 {len(failed_codes)}"
)
return {
"success_count": success_count,
"failed_count": len(failed_codes),
"failed_codes": failed_codes,
"batch_id": batch_id,
}
# ============================================================
# 管理员:导出
# ============================================================
async def export_codes(
self,
*,
batch_id: str | None = None,
status: RedeemCodeStatus | None = None,
limit: int = 10000,
) -> list[dict[str, Any]]:
"""
导出兑换码
Args:
batch_id: 批次 ID 过滤
status: 状态过滤
limit: 最大导出数量
Returns:
兑换码数据列表
"""
codes = await self.code_repo.get_all_with_filters(
batch_id=batch_id,
status=status,
limit=limit,
)
result = []
for code in codes:
result.append({
"code": code.code,
"face_value": f"{code.face_value / 1000:.2f}",
"status": code.status.value,
"max_uses": code.max_uses,
"used_count": code.used_count,
"expires_at": code.expires_at.isoformat() if code.expires_at else None,
"created_at": code.created_at.isoformat(),
"used_at": code.used_at.isoformat() if code.used_at else None,
"used_by": code.used_by,
})
return result
# ============================================================
# 管理员:查询
# ============================================================
async def get_codes(
self,
*,
offset: int = 0,
limit: int = 20,
status: RedeemCodeStatus | None = None,
batch_id: str | None = None,
code_like: str | None = None,
created_after: datetime | None = None,
created_before: datetime | None = None,
) -> tuple[list[RedeemCode], int]:
"""
获取兑换码列表
Returns:
(兑换码列表, 总数)
"""
codes = await self.code_repo.get_all_with_filters(
offset=offset,
limit=limit,
status=status,
batch_id=batch_id,
code_like=code_like,
created_after=created_after,
created_before=created_before,
)
total = await self.code_repo.count_with_filters(
status=status,
batch_id=batch_id,
code_like=code_like,
created_after=created_after,
created_before=created_before,
)
return codes, total
async def get_code_detail(self, code_id: str) -> RedeemCode:
"""
获取兑换码详情
Args:
code_id: 兑换码 ID
Returns:
兑换码记录
"""
code = await self.code_repo.get_by_id(code_id)
if not code:
raise ResourceNotFoundError("兑换码不存在", "redeem_code", code_id)
return code
async def disable_code(self, code_id: str) -> RedeemCode:
"""
禁用兑换码
Args:
code_id: 兑换码 ID
Returns:
更新后的兑换码
"""
code = await self.get_code_detail(code_id)
code = await self.code_repo.disable_code(code)
await self.code_repo.commit()
logger.info(f"兑换码已禁用: {code.code}")
return code
async def enable_code(self, code_id: str) -> RedeemCode:
"""
启用兑换码
Args:
code_id: 兑换码 ID
Returns:
更新后的兑换码
"""
code = await self.get_code_detail(code_id)
code = await self.code_repo.enable_code(code)
await self.code_repo.commit()
logger.info(f"兑换码已启用: {code.code}")
return code
# ============================================================
# 管理员:批次管理
# ============================================================
async def get_batches(
self,
*,
offset: int = 0,
limit: int = 20,
) -> tuple[list[RedeemCodeBatch], int]:
"""
获取批次列表
Returns:
(批次列表, 总数)
"""
batches = await self.batch_repo.get_all_batches(
offset=offset,
limit=limit,
)
total = await self.batch_repo.count()
return batches, total
async def get_batch_detail(self, batch_id: str) -> RedeemCodeBatch:
"""
获取批次详情
Args:
batch_id: 批次 ID
Returns:
批次记录
"""
batch = await self.batch_repo.get_by_id(batch_id)
if not batch:
raise ResourceNotFoundError("批次不存在", "batch", batch_id)
return batch
# ============================================================
# 管理员:使用日志
# ============================================================
async def get_usage_logs(
self,
*,
offset: int = 0,
limit: int = 20,
redeem_code_id: str | None = None,
user_id: str | None = None,
code_like: str | None = None,
created_after: datetime | None = None,
created_before: datetime | None = None,
) -> tuple[list[RedeemCodeUsageLog], int]:
"""
获取使用日志
Returns:
(日志列表, 总数)
"""
logs = await self.log_repo.get_all_with_filters(
offset=offset,
limit=limit,
redeem_code_id=redeem_code_id,
user_id=user_id,
code_like=code_like,
created_after=created_after,
created_before=created_before,
)
total = await self.log_repo.count_with_filters(
redeem_code_id=redeem_code_id,
user_id=user_id,
code_like=code_like,
created_after=created_after,
created_before=created_before,
)
return logs, total

175
app/services/user.py Normal file
View File

@@ -0,0 +1,175 @@
"""
用户服务
处理用户相关的业务逻辑。
"""
from datetime import datetime, timezone
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.exceptions import (
UserAlreadyExistsError,
UserNotFoundError,
)
from app.core.security import hash_password
from app.models.user import User
from app.repositories.user import UserRepository
from app.schemas.user import UserCreate, UserUpdate
class UserService:
"""用户服务"""
def __init__(self, session: AsyncSession):
"""
初始化用户服务
Args:
session: 数据库会话
"""
self.session = session
self.user_repo = UserRepository(session)
async def create_user(self, user_data: UserCreate) -> User:
"""
创建新用户
Args:
user_data: 用户注册数据
Returns:
新创建的用户
Raises:
UserAlreadyExistsError: 用户名或邮箱已存在
"""
# 检查用户名是否已存在
if await self.user_repo.exists_by_username(user_data.username):
raise UserAlreadyExistsError("用户名")
# 检查邮箱是否已存在
if user_data.email and await self.user_repo.exists_by_email(user_data.email):
raise UserAlreadyExistsError("邮箱")
# 创建用户
user = await self.user_repo.create(
username=user_data.username.lower(),
email=user_data.email.lower() if user_data.email else None,
nickname=user_data.nickname,
hashed_password=hash_password(user_data.password),
)
await self.user_repo.commit()
return user
async def get_user_by_id(self, user_id: str) -> User:
"""
通过 ID 获取用户
Args:
user_id: 用户 ID
Returns:
用户对象
Raises:
UserNotFoundError: 用户不存在
"""
user = await self.user_repo.get_by_id(user_id)
if not user:
raise UserNotFoundError(user_id)
return user
async def get_user_by_username(self, username: str) -> User | None:
"""
通过用户名获取用户
Args:
username: 用户名
Returns:
用户对象或 None
"""
return await self.user_repo.get_by_username(username)
async def update_user(
self,
user_id: str,
update_data: UserUpdate,
) -> User:
"""
更新用户信息
Args:
user_id: 用户 ID
update_data: 更新数据
Returns:
更新后的用户
Raises:
UserNotFoundError: 用户不存在
UserAlreadyExistsError: 邮箱已被使用
"""
user = await self.get_user_by_id(user_id)
# 检查邮箱是否被其他用户使用
if update_data.email:
existing_user = await self.user_repo.get_by_email(update_data.email)
if existing_user and existing_user.id != user_id:
raise UserAlreadyExistsError("邮箱")
# 准备更新数据
update_dict = update_data.model_dump(exclude_unset=True)
if update_dict.get("email"):
update_dict["email"] = update_dict["email"].lower()
# 更新用户
user = await self.user_repo.update(user, **update_dict)
await self.user_repo.commit()
return user
async def update_last_login(self, user: User) -> None:
"""
更新用户最后登录时间
Args:
user: 用户对象
"""
await self.user_repo.update(
user,
last_login_at=datetime.now(timezone.utc),
)
await self.user_repo.commit()
async def deactivate_user(self, user_id: str) -> User:
"""
禁用用户账户
Args:
user_id: 用户 ID
Returns:
更新后的用户
"""
user = await self.get_user_by_id(user_id)
user = await self.user_repo.update(user, is_active=False)
await self.user_repo.commit()
return user
async def activate_user(self, user_id: str) -> User:
"""
激活用户账户
Args:
user_id: 用户 ID
Returns:
更新后的用户
"""
user = await self.get_user_by_id(user_id)
user = await self.user_repo.update(user, is_active=True)
await self.user_repo.commit()
return user