提供基本前后端骨架
This commit is contained in:
16
app/services/__init__.py
Normal file
16
app/services/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""业务服务层"""
|
||||
|
||||
from app.services.auth import AuthService
|
||||
from app.services.oauth2 import OAuth2Service
|
||||
from app.services.user import UserService
|
||||
from app.services.balance import BalanceService
|
||||
from app.services.redeem_code import RedeemCodeService
|
||||
|
||||
__all__ = [
|
||||
"AuthService",
|
||||
"OAuth2Service",
|
||||
"UserService",
|
||||
"BalanceService",
|
||||
"RedeemCodeService",
|
||||
]
|
||||
|
||||
296
app/services/auth.py
Normal file
296
app/services/auth.py
Normal file
@@ -0,0 +1,296 @@
|
||||
"""
|
||||
认证服务
|
||||
|
||||
处理用户认证相关的业务逻辑。
|
||||
"""
|
||||
|
||||
from datetime import timedelta
|
||||
|
||||
import jwt
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.exceptions import (
|
||||
InvalidCredentialsError,
|
||||
PasswordValidationError,
|
||||
TokenError,
|
||||
TokenExpiredError,
|
||||
UserDisabledError,
|
||||
UserNotFoundError,
|
||||
)
|
||||
from app.core.security import (
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
decode_token,
|
||||
hash_password,
|
||||
password_needs_rehash,
|
||||
verify_password,
|
||||
)
|
||||
from app.models.user import User
|
||||
from app.repositories.user import UserRepository
|
||||
from app.schemas.auth import PasswordChangeRequest, TokenResponse
|
||||
from app.services.user import UserService
|
||||
|
||||
|
||||
class AuthService:
|
||||
"""认证服务"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
"""
|
||||
初始化认证服务
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
"""
|
||||
self.session = session
|
||||
self.user_repo = UserRepository(session)
|
||||
self.user_service = UserService(session)
|
||||
|
||||
async def authenticate(
|
||||
self,
|
||||
username: str,
|
||||
password: str,
|
||||
) -> User:
|
||||
"""
|
||||
验证用户凭证
|
||||
|
||||
Args:
|
||||
username: 用户名或邮箱
|
||||
password: 密码
|
||||
|
||||
Returns:
|
||||
验证成功的用户对象
|
||||
|
||||
Raises:
|
||||
InvalidCredentialsError: 凭证无效
|
||||
UserDisabledError: 用户被禁用
|
||||
"""
|
||||
# 查找用户(支持用户名或邮箱登录)
|
||||
user = await self.user_repo.get_by_username_or_email(username)
|
||||
|
||||
if not user:
|
||||
# 防止时序攻击:即使用户不存在也进行密码验证
|
||||
verify_password(password, "$argon2id$v=19$m=65536,t=3,p=4$dummy$dummy")
|
||||
raise InvalidCredentialsError()
|
||||
|
||||
# 验证密码
|
||||
if not verify_password(password, user.hashed_password):
|
||||
raise InvalidCredentialsError()
|
||||
|
||||
# 检查用户状态
|
||||
if not user.is_active:
|
||||
raise UserDisabledError()
|
||||
|
||||
# 检查是否需要重新哈希密码(参数升级)
|
||||
if password_needs_rehash(user.hashed_password):
|
||||
await self.user_repo.update(
|
||||
user,
|
||||
hashed_password=hash_password(password),
|
||||
)
|
||||
await self.user_repo.commit()
|
||||
|
||||
return user
|
||||
|
||||
async def login(
|
||||
self,
|
||||
username: str,
|
||||
password: str,
|
||||
) -> tuple[User, TokenResponse]:
|
||||
"""
|
||||
用户登录
|
||||
|
||||
Args:
|
||||
username: 用户名或邮箱
|
||||
password: 密码
|
||||
|
||||
Returns:
|
||||
(用户对象, 令牌响应)
|
||||
"""
|
||||
user = await self.authenticate(username, password)
|
||||
|
||||
# 更新最后登录时间
|
||||
await self.user_service.update_last_login(user)
|
||||
|
||||
# 生成令牌
|
||||
tokens = self._create_tokens(user)
|
||||
|
||||
return user, tokens
|
||||
|
||||
def _create_tokens(self, user: User) -> TokenResponse:
|
||||
"""
|
||||
为用户创建访问令牌和刷新令牌
|
||||
|
||||
Args:
|
||||
user: 用户对象
|
||||
|
||||
Returns:
|
||||
令牌响应
|
||||
"""
|
||||
access_token = create_access_token(
|
||||
subject=user.id,
|
||||
extra_claims={
|
||||
"username": user.username,
|
||||
"is_superuser": user.is_superuser,
|
||||
},
|
||||
)
|
||||
|
||||
refresh_token = create_refresh_token(subject=user.id)
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
token_type="Bearer",
|
||||
expires_in=settings.access_token_expire_minutes * 60,
|
||||
)
|
||||
|
||||
async def refresh_tokens(self, refresh_token: str) -> TokenResponse:
|
||||
"""
|
||||
刷新访问令牌
|
||||
|
||||
Args:
|
||||
refresh_token: 刷新令牌
|
||||
|
||||
Returns:
|
||||
新的令牌响应
|
||||
|
||||
Raises:
|
||||
TokenError: 令牌无效
|
||||
TokenExpiredError: 令牌已过期
|
||||
"""
|
||||
try:
|
||||
payload = decode_token(refresh_token)
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise TokenExpiredError()
|
||||
except jwt.InvalidTokenError:
|
||||
raise TokenError()
|
||||
|
||||
# 验证令牌类型
|
||||
if payload.get("type") != "refresh":
|
||||
raise TokenError("无效的令牌类型")
|
||||
|
||||
# 获取用户
|
||||
user_id = payload.get("sub")
|
||||
if not user_id:
|
||||
raise TokenError()
|
||||
|
||||
user = await self.user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
raise UserNotFoundError(user_id)
|
||||
|
||||
if not user.is_active:
|
||||
raise UserDisabledError()
|
||||
|
||||
# 生成新令牌
|
||||
return self._create_tokens(user)
|
||||
|
||||
async def change_password(
|
||||
self,
|
||||
user_id: str,
|
||||
password_data: PasswordChangeRequest,
|
||||
) -> None:
|
||||
"""
|
||||
修改用户密码
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
password_data: 密码修改数据
|
||||
|
||||
Raises:
|
||||
UserNotFoundError: 用户不存在
|
||||
InvalidCredentialsError: 当前密码错误
|
||||
PasswordValidationError: 新密码不符合要求
|
||||
"""
|
||||
user = await self.user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
raise UserNotFoundError(user_id)
|
||||
|
||||
# 验证当前密码
|
||||
if not verify_password(password_data.current_password, user.hashed_password):
|
||||
raise InvalidCredentialsError("当前密码错误")
|
||||
|
||||
# 验证新密码强度
|
||||
self._validate_password_strength(password_data.new_password)
|
||||
|
||||
# 更新密码
|
||||
await self.user_repo.update(
|
||||
user,
|
||||
hashed_password=hash_password(password_data.new_password),
|
||||
)
|
||||
await self.user_repo.commit()
|
||||
|
||||
def _validate_password_strength(self, password: str) -> None:
|
||||
"""
|
||||
验证密码强度
|
||||
|
||||
Args:
|
||||
password: 密码
|
||||
|
||||
Raises:
|
||||
PasswordValidationError: 密码不符合要求
|
||||
"""
|
||||
import re
|
||||
|
||||
errors: list[str] = []
|
||||
|
||||
if len(password) < settings.password_min_length:
|
||||
errors.append(f"密码长度不能少于 {settings.password_min_length} 位")
|
||||
|
||||
if len(password) > settings.password_max_length:
|
||||
errors.append(f"密码长度不能超过 {settings.password_max_length} 位")
|
||||
|
||||
if settings.password_require_uppercase and not re.search(r"[A-Z]", password):
|
||||
errors.append("至少包含一个大写字母")
|
||||
|
||||
if settings.password_require_lowercase and not re.search(r"[a-z]", password):
|
||||
errors.append("至少包含一个小写字母")
|
||||
|
||||
if settings.password_require_digit and not re.search(r"\d", password):
|
||||
errors.append("至少包含一个数字")
|
||||
|
||||
if settings.password_require_special and not re.search(r"[!@#$%^&*(),.?\":{}|<>]", password):
|
||||
errors.append("至少包含一个特殊字符")
|
||||
|
||||
if errors:
|
||||
raise PasswordValidationError(";".join(errors))
|
||||
|
||||
async def get_current_user(self, token: str) -> User:
|
||||
"""
|
||||
从令牌获取当前用户
|
||||
|
||||
Args:
|
||||
token: 访问令牌
|
||||
|
||||
Returns:
|
||||
用户对象
|
||||
|
||||
Raises:
|
||||
TokenError: 令牌无效
|
||||
TokenExpiredError: 令牌已过期
|
||||
UserNotFoundError: 用户不存在
|
||||
UserDisabledError: 用户被禁用
|
||||
"""
|
||||
try:
|
||||
payload = decode_token(token)
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise TokenExpiredError()
|
||||
except jwt.InvalidTokenError:
|
||||
raise TokenError()
|
||||
|
||||
# 验证令牌类型
|
||||
if payload.get("type") != "access":
|
||||
raise TokenError("无效的令牌类型")
|
||||
|
||||
user_id = payload.get("sub")
|
||||
if not user_id:
|
||||
raise TokenError()
|
||||
|
||||
user = await self.user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
raise UserNotFoundError(user_id)
|
||||
|
||||
if not user.is_active:
|
||||
raise UserDisabledError()
|
||||
|
||||
return user
|
||||
|
||||
934
app/services/balance.py
Normal file
934
app/services/balance.py
Normal file
@@ -0,0 +1,934 @@
|
||||
"""
|
||||
余额服务
|
||||
|
||||
处理余额相关的业务逻辑。
|
||||
|
||||
设计说明:
|
||||
- 所有金额操作使用整数单位(units),避免浮点精度问题
|
||||
- 扣款操作使用行级锁(悲观锁)确保原子性
|
||||
- 充值操作使用乐观锁,配合重试机制
|
||||
- 每笔操作都记录交易流水
|
||||
|
||||
预扣款流程(内部方法,用于耗时付费操作):
|
||||
1. pre_authorize() - 预扣款,冻结金额,快速释放锁,返回交易ID
|
||||
2. 执行耗时的付费操作(使用交易ID追踪)
|
||||
3. confirm() 或 cancel() - 根据操作结果确认或取消
|
||||
|
||||
推荐使用上下文管理器 deduction_context() 自动处理确认/取消。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, AsyncIterator, Callable, Awaitable, TypeVar
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.exceptions import (
|
||||
AppException,
|
||||
ResourceNotFoundError,
|
||||
ValidationError,
|
||||
)
|
||||
from app.models.balance import (
|
||||
UserBalance,
|
||||
BalanceTransaction,
|
||||
TransactionType,
|
||||
TransactionStatus,
|
||||
)
|
||||
from app.repositories.balance import BalanceRepository, TransactionRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class InsufficientBalanceError(AppException):
|
||||
"""余额不足"""
|
||||
|
||||
def __init__(self, required: int, available: int):
|
||||
super().__init__(
|
||||
f"余额不足,需要 {required / 1000:.2f},当前可用 {available / 1000:.2f}",
|
||||
"INSUFFICIENT_BALANCE",
|
||||
{"required_units": required, "available_units": available},
|
||||
)
|
||||
|
||||
|
||||
class DuplicateTransactionError(AppException):
|
||||
"""重复交易"""
|
||||
|
||||
def __init__(self, idempotency_key: str):
|
||||
super().__init__(
|
||||
"该交易已处理",
|
||||
"DUPLICATE_TRANSACTION",
|
||||
{"idempotency_key": idempotency_key},
|
||||
)
|
||||
|
||||
|
||||
class ConcurrencyError(AppException):
|
||||
"""并发冲突"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
"操作冲突,请重试",
|
||||
"CONCURRENCY_ERROR",
|
||||
)
|
||||
|
||||
|
||||
class TransactionNotFoundError(AppException):
|
||||
"""交易不存在"""
|
||||
|
||||
def __init__(self, transaction_id: str):
|
||||
super().__init__(
|
||||
"交易记录不存在",
|
||||
"TRANSACTION_NOT_FOUND",
|
||||
{"transaction_id": transaction_id},
|
||||
)
|
||||
|
||||
|
||||
class TransactionStateError(AppException):
|
||||
"""交易状态错误"""
|
||||
|
||||
def __init__(self, transaction_id: str, current_status: str, expected_status: str = "pending"):
|
||||
super().__init__(
|
||||
f"交易状态无效:当前 {current_status},预期 {expected_status}",
|
||||
"TRANSACTION_STATE_ERROR",
|
||||
{
|
||||
"transaction_id": transaction_id,
|
||||
"current_status": current_status,
|
||||
"expected_status": expected_status,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PreAuthResult:
|
||||
"""
|
||||
预授权结果
|
||||
|
||||
包含交易ID和相关信息,用于后续确认或取消操作。
|
||||
"""
|
||||
|
||||
transaction_id: str
|
||||
"""交易ID,用于后续 confirm/cancel 操作"""
|
||||
|
||||
user_id: str
|
||||
"""用户ID"""
|
||||
|
||||
amount_units: int
|
||||
"""预扣款金额(单位额度)"""
|
||||
|
||||
frozen_at: datetime
|
||||
"""冻结时间"""
|
||||
|
||||
@property
|
||||
def amount_display(self) -> str:
|
||||
"""显示金额(2位小数)"""
|
||||
return f"{self.amount_units / 1000:.2f}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeductionResult:
|
||||
"""
|
||||
扣款结果
|
||||
|
||||
包含扣款操作的完整信息。
|
||||
"""
|
||||
|
||||
transaction_id: str
|
||||
"""交易ID"""
|
||||
|
||||
status: TransactionStatus
|
||||
"""交易状态"""
|
||||
|
||||
amount_units: int
|
||||
"""实际扣款金额(单位额度)"""
|
||||
|
||||
balance_before: int
|
||||
"""扣款前余额"""
|
||||
|
||||
balance_after: int
|
||||
"""扣款后余额"""
|
||||
|
||||
@property
|
||||
def success(self) -> bool:
|
||||
"""是否扣款成功"""
|
||||
return self.status == TransactionStatus.COMPLETED
|
||||
|
||||
@property
|
||||
def amount_display(self) -> str:
|
||||
"""显示金额"""
|
||||
return f"{abs(self.amount_units) / 1000:.2f}"
|
||||
|
||||
@property
|
||||
def balance_before_display(self) -> str:
|
||||
"""显示扣款前余额"""
|
||||
return f"{self.balance_before / 1000:.2f}"
|
||||
|
||||
@property
|
||||
def balance_after_display(self) -> str:
|
||||
"""显示扣款后余额"""
|
||||
return f"{self.balance_after / 1000:.2f}"
|
||||
|
||||
|
||||
class BalanceService:
|
||||
"""余额服务"""
|
||||
|
||||
# 乐观锁最大重试次数
|
||||
MAX_RETRIES = 3
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
"""
|
||||
初始化余额服务
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
"""
|
||||
self.session = session
|
||||
self.balance_repo = BalanceRepository(session)
|
||||
self.transaction_repo = TransactionRepository(session)
|
||||
|
||||
# ============================================================
|
||||
# 余额查询
|
||||
# ============================================================
|
||||
|
||||
async def get_balance(self, user_id: str) -> UserBalance:
|
||||
"""
|
||||
获取用户余额
|
||||
|
||||
如果用户没有余额账户,自动创建一个。
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
|
||||
Returns:
|
||||
余额账户
|
||||
"""
|
||||
balance = await self.balance_repo.get_or_create(user_id)
|
||||
await self.balance_repo.commit()
|
||||
return balance
|
||||
|
||||
async def get_balance_detail(self, user_id: str) -> dict[str, Any]:
|
||||
"""
|
||||
获取用户余额详情
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
|
||||
Returns:
|
||||
余额详情字典
|
||||
"""
|
||||
balance = await self.get_balance(user_id)
|
||||
return {
|
||||
"user_id": balance.user_id,
|
||||
"balance_units": balance.balance,
|
||||
"frozen_units": balance.frozen_balance,
|
||||
"available_units": balance.available_balance,
|
||||
"total_recharged_units": balance.total_recharged,
|
||||
"total_consumed_units": balance.total_consumed,
|
||||
}
|
||||
|
||||
async def get_transactions(
|
||||
self,
|
||||
user_id: str,
|
||||
*,
|
||||
offset: int = 0,
|
||||
limit: int = 20,
|
||||
transaction_type: TransactionType | None = None,
|
||||
) -> tuple[list[BalanceTransaction], int]:
|
||||
"""
|
||||
获取用户交易记录
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
offset: 偏移量
|
||||
limit: 限制数量
|
||||
transaction_type: 交易类型过滤
|
||||
|
||||
Returns:
|
||||
(交易记录列表, 总数)
|
||||
"""
|
||||
transactions = await self.transaction_repo.get_by_user_id(
|
||||
user_id,
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
transaction_type=transaction_type,
|
||||
)
|
||||
total = await self.transaction_repo.count_by_user_id(
|
||||
user_id,
|
||||
transaction_type=transaction_type,
|
||||
)
|
||||
return transactions, total
|
||||
|
||||
# ============================================================
|
||||
# 扣款操作(使用行级锁 - 悲观锁)
|
||||
# ============================================================
|
||||
|
||||
async def deduct(
|
||||
self,
|
||||
user_id: str,
|
||||
amount_units: int,
|
||||
*,
|
||||
reference_type: str | None = None,
|
||||
reference_id: str | None = None,
|
||||
description: str | None = None,
|
||||
idempotency_key: str | None = None,
|
||||
) -> BalanceTransaction:
|
||||
"""
|
||||
扣款
|
||||
|
||||
使用行级锁确保原子性,防止并发扣款导致余额变负。
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
amount_units: 扣款金额(单位额度,正数)
|
||||
reference_type: 关联业务类型
|
||||
reference_id: 关联业务 ID
|
||||
description: 交易描述
|
||||
idempotency_key: 幂等键
|
||||
|
||||
Returns:
|
||||
交易记录
|
||||
|
||||
Raises:
|
||||
InsufficientBalanceError: 余额不足
|
||||
DuplicateTransactionError: 重复交易
|
||||
"""
|
||||
if amount_units <= 0:
|
||||
raise ValidationError("扣款金额必须大于 0")
|
||||
|
||||
# 检查幂等性
|
||||
if idempotency_key:
|
||||
existing = await self.transaction_repo.get_by_idempotency_key(
|
||||
idempotency_key
|
||||
)
|
||||
if existing:
|
||||
raise DuplicateTransactionError(idempotency_key)
|
||||
|
||||
# 获取余额账户并加锁
|
||||
balance = await self.balance_repo.get_or_create_for_update(user_id)
|
||||
|
||||
# 检查可用余额
|
||||
if balance.available_balance < amount_units:
|
||||
raise InsufficientBalanceError(amount_units, balance.available_balance)
|
||||
|
||||
# 记录扣款前余额
|
||||
balance_before = balance.balance
|
||||
|
||||
# 执行扣款
|
||||
balance.balance -= amount_units
|
||||
balance.total_consumed += amount_units
|
||||
balance.version += 1
|
||||
|
||||
# 创建交易记录
|
||||
transaction = await self.transaction_repo.create(
|
||||
user_id=user_id,
|
||||
balance_account_id=balance.id,
|
||||
transaction_type=TransactionType.DEDUCTION,
|
||||
status=TransactionStatus.COMPLETED,
|
||||
amount=-amount_units, # 负数表示支出
|
||||
balance_before=balance_before,
|
||||
balance_after=balance.balance,
|
||||
reference_type=reference_type,
|
||||
reference_id=reference_id,
|
||||
description=description,
|
||||
idempotency_key=idempotency_key,
|
||||
)
|
||||
|
||||
await self.balance_repo.commit()
|
||||
|
||||
logger.info(
|
||||
f"用户 {user_id} 扣款成功: {amount_units} 单位, "
|
||||
f"余额 {balance_before} -> {balance.balance}"
|
||||
)
|
||||
|
||||
return transaction
|
||||
|
||||
# ============================================================
|
||||
# 预扣款流程(内部方法,用于耗时付费操作)
|
||||
# ============================================================
|
||||
|
||||
async def pre_authorize(
|
||||
self,
|
||||
user_id: str,
|
||||
amount_units: int,
|
||||
*,
|
||||
reference_type: str | None = None,
|
||||
reference_id: str | None = None,
|
||||
description: str | None = None,
|
||||
) -> PreAuthResult:
|
||||
"""
|
||||
预授权扣款(内部方法)
|
||||
|
||||
冻结指定金额,快速释放数据库锁,返回交易ID供后续操作使用。
|
||||
此方法设计用于耗时的付费操作场景。
|
||||
|
||||
使用流程:
|
||||
1. 调用 pre_authorize() 获取 PreAuthResult
|
||||
2. 执行可能失败的耗时操作
|
||||
3. 根据操作结果调用 confirm() 或 cancel()
|
||||
|
||||
推荐使用 deduction_context() 上下文管理器自动处理。
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
amount_units: 扣款金额(单位额度,正数)
|
||||
reference_type: 关联业务类型(如 api_call, service)
|
||||
reference_id: 关联业务 ID
|
||||
description: 交易描述
|
||||
|
||||
Returns:
|
||||
PreAuthResult: 预授权结果,包含交易ID
|
||||
|
||||
Raises:
|
||||
InsufficientBalanceError: 余额不足
|
||||
ValidationError: 参数无效
|
||||
"""
|
||||
if amount_units <= 0:
|
||||
raise ValidationError("预扣款金额必须大于 0")
|
||||
|
||||
# 获取余额账户并加锁(短暂持有)
|
||||
balance = await self.balance_repo.get_or_create_for_update(user_id)
|
||||
|
||||
# 检查可用余额
|
||||
if balance.available_balance < amount_units:
|
||||
raise InsufficientBalanceError(amount_units, balance.available_balance)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# 执行冻结
|
||||
balance.frozen_balance += amount_units
|
||||
balance.version += 1
|
||||
|
||||
# 创建待处理交易记录
|
||||
transaction = await self.transaction_repo.create(
|
||||
user_id=user_id,
|
||||
balance_account_id=balance.id,
|
||||
transaction_type=TransactionType.DEDUCTION,
|
||||
status=TransactionStatus.PENDING,
|
||||
amount=-amount_units,
|
||||
balance_before=balance.balance,
|
||||
balance_after=balance.balance, # 尚未实际扣款
|
||||
reference_type=reference_type,
|
||||
reference_id=reference_id,
|
||||
description=description,
|
||||
remark=f"预授权冻结: {amount_units} 单位",
|
||||
)
|
||||
|
||||
# 快速提交释放锁
|
||||
await self.balance_repo.commit()
|
||||
|
||||
logger.info(
|
||||
f"用户 {user_id} 预授权成功: {amount_units} 单位, "
|
||||
f"交易ID: {transaction.id}"
|
||||
)
|
||||
|
||||
return PreAuthResult(
|
||||
transaction_id=transaction.id,
|
||||
user_id=user_id,
|
||||
amount_units=amount_units,
|
||||
frozen_at=now,
|
||||
)
|
||||
|
||||
async def confirm(
|
||||
self,
|
||||
transaction_id: str,
|
||||
*,
|
||||
actual_amount_units: int | None = None,
|
||||
) -> DeductionResult:
|
||||
"""
|
||||
确认预授权扣款(内部方法)
|
||||
|
||||
将预冻结的金额实际扣除。支持部分扣款。
|
||||
|
||||
Args:
|
||||
transaction_id: 预授权交易 ID
|
||||
actual_amount_units: 实际扣款金额(可选,用于部分扣款,默认全额)
|
||||
|
||||
Returns:
|
||||
DeductionResult: 扣款结果
|
||||
|
||||
Raises:
|
||||
TransactionNotFoundError: 交易不存在
|
||||
TransactionStateError: 交易状态不是 PENDING
|
||||
ValidationError: 参数无效
|
||||
"""
|
||||
transaction = await self.transaction_repo.get_by_id(transaction_id)
|
||||
if not transaction:
|
||||
raise TransactionNotFoundError(transaction_id)
|
||||
|
||||
if transaction.status != TransactionStatus.PENDING:
|
||||
raise TransactionStateError(
|
||||
transaction_id,
|
||||
transaction.status.value,
|
||||
)
|
||||
|
||||
# 获取余额账户并加锁
|
||||
balance = await self.balance_repo.get_by_user_id_for_update(transaction.user_id)
|
||||
if not balance:
|
||||
raise ResourceNotFoundError("余额账户不存在")
|
||||
|
||||
frozen_amount = abs(transaction.amount)
|
||||
|
||||
# 确定实际扣款金额
|
||||
if actual_amount_units is not None:
|
||||
if actual_amount_units <= 0:
|
||||
raise ValidationError("实际扣款金额必须大于 0")
|
||||
if actual_amount_units > frozen_amount:
|
||||
raise ValidationError(
|
||||
f"实际扣款金额 ({actual_amount_units}) 不能超过预授权金额 ({frozen_amount})"
|
||||
)
|
||||
deduct_amount = actual_amount_units
|
||||
else:
|
||||
deduct_amount = frozen_amount
|
||||
|
||||
# 检查冻结金额
|
||||
if balance.frozen_balance < frozen_amount:
|
||||
raise ValidationError("冻结金额不足,可能已被其他操作修改")
|
||||
|
||||
balance_before = balance.balance
|
||||
|
||||
# 执行扣款:解冻全部,扣除实际金额
|
||||
balance.frozen_balance -= frozen_amount
|
||||
balance.balance -= deduct_amount
|
||||
balance.total_consumed += deduct_amount
|
||||
balance.version += 1
|
||||
|
||||
# 更新交易记录
|
||||
transaction.status = TransactionStatus.COMPLETED
|
||||
transaction.amount = -deduct_amount # 更新为实际扣款金额
|
||||
transaction.balance_after = balance.balance
|
||||
|
||||
await self.balance_repo.commit()
|
||||
|
||||
logger.info(
|
||||
f"用户 {transaction.user_id} 确认扣款: {deduct_amount} 单位, "
|
||||
f"余额 {balance_before} -> {balance.balance}"
|
||||
)
|
||||
|
||||
return DeductionResult(
|
||||
transaction_id=transaction.id,
|
||||
status=TransactionStatus.COMPLETED,
|
||||
amount_units=deduct_amount,
|
||||
balance_before=balance_before,
|
||||
balance_after=balance.balance,
|
||||
)
|
||||
|
||||
async def cancel(
|
||||
self,
|
||||
transaction_id: str,
|
||||
*,
|
||||
reason: str | None = None,
|
||||
) -> DeductionResult:
|
||||
"""
|
||||
取消预授权扣款(内部方法)
|
||||
|
||||
解冻预授权的金额,退回用户可用余额。
|
||||
|
||||
Args:
|
||||
transaction_id: 预授权交易 ID
|
||||
reason: 取消原因(可选,记录在日志中)
|
||||
|
||||
Returns:
|
||||
DeductionResult: 取消结果
|
||||
|
||||
Raises:
|
||||
TransactionNotFoundError: 交易不存在
|
||||
TransactionStateError: 交易状态不是 PENDING
|
||||
"""
|
||||
transaction = await self.transaction_repo.get_by_id(transaction_id)
|
||||
if not transaction:
|
||||
raise TransactionNotFoundError(transaction_id)
|
||||
|
||||
if transaction.status != TransactionStatus.PENDING:
|
||||
raise TransactionStateError(
|
||||
transaction_id,
|
||||
transaction.status.value,
|
||||
)
|
||||
|
||||
# 获取余额账户并加锁
|
||||
balance = await self.balance_repo.get_by_user_id_for_update(transaction.user_id)
|
||||
if not balance:
|
||||
raise ResourceNotFoundError("余额账户不存在")
|
||||
|
||||
frozen_amount = abs(transaction.amount)
|
||||
|
||||
# 解冻
|
||||
balance.frozen_balance -= frozen_amount
|
||||
balance.version += 1
|
||||
|
||||
# 更新交易记录
|
||||
transaction.status = TransactionStatus.CANCELLED
|
||||
if reason:
|
||||
transaction.remark = f"{transaction.remark or ''}; 取消原因: {reason}"
|
||||
|
||||
await self.balance_repo.commit()
|
||||
|
||||
logger.info(
|
||||
f"用户 {transaction.user_id} 取消预授权: {frozen_amount} 单位"
|
||||
+ (f", 原因: {reason}" if reason else "")
|
||||
)
|
||||
|
||||
return DeductionResult(
|
||||
transaction_id=transaction.id,
|
||||
status=TransactionStatus.CANCELLED,
|
||||
amount_units=0, # 实际未扣款
|
||||
balance_before=balance.balance,
|
||||
balance_after=balance.balance,
|
||||
)
|
||||
|
||||
@asynccontextmanager
|
||||
async def deduction_context(
|
||||
self,
|
||||
user_id: str,
|
||||
amount_units: int,
|
||||
*,
|
||||
reference_type: str | None = None,
|
||||
reference_id: str | None = None,
|
||||
description: str | None = None,
|
||||
auto_cancel_on_error: bool = True,
|
||||
) -> AsyncIterator[PreAuthResult]:
|
||||
"""
|
||||
扣款上下文管理器(推荐使用)
|
||||
|
||||
提供简便的预扣款流程,自动处理确认和取消。
|
||||
异常时自动取消预授权,退回冻结金额。
|
||||
|
||||
用法示例:
|
||||
```python
|
||||
async with balance_service.deduction_context(
|
||||
user_id,
|
||||
1000, # 扣款金额(单位额度)
|
||||
reference_type="api_call",
|
||||
description="API调用费用",
|
||||
) as pre_auth:
|
||||
# pre_auth.transaction_id 可用于追踪
|
||||
# 执行可能失败的耗时操作
|
||||
result = await call_external_api()
|
||||
if not result.success:
|
||||
raise Exception("API 调用失败")
|
||||
# 成功退出时自动确认扣款
|
||||
# 异常退出时自动取消预授权(如果 auto_cancel_on_error=True)
|
||||
```
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
amount_units: 扣款金额(单位额度)
|
||||
reference_type: 关联业务类型
|
||||
reference_id: 关联业务 ID
|
||||
description: 交易描述
|
||||
auto_cancel_on_error: 异常时是否自动取消(默认 True)
|
||||
|
||||
Yields:
|
||||
PreAuthResult: 预授权结果,包含交易ID
|
||||
|
||||
Raises:
|
||||
InsufficientBalanceError: 余额不足
|
||||
"""
|
||||
# 第一阶段:预授权
|
||||
pre_auth = await self.pre_authorize(
|
||||
user_id,
|
||||
amount_units,
|
||||
reference_type=reference_type,
|
||||
reference_id=reference_id,
|
||||
description=description,
|
||||
)
|
||||
|
||||
try:
|
||||
yield pre_auth
|
||||
# 正常退出:确认扣款
|
||||
await self.confirm(pre_auth.transaction_id)
|
||||
except Exception as e:
|
||||
# 异常退出:取消预授权
|
||||
if auto_cancel_on_error:
|
||||
try:
|
||||
await self.cancel(
|
||||
pre_auth.transaction_id,
|
||||
reason=f"操作失败: {str(e)[:200]}",
|
||||
)
|
||||
except Exception as cancel_error:
|
||||
logger.error(
|
||||
f"取消预授权失败: {pre_auth.transaction_id}, "
|
||||
f"错误: {cancel_error}"
|
||||
)
|
||||
raise
|
||||
|
||||
async def execute_with_deduction(
|
||||
self,
|
||||
user_id: str,
|
||||
amount_units: int,
|
||||
operation: Callable[[PreAuthResult], Awaitable[T]],
|
||||
*,
|
||||
reference_type: str | None = None,
|
||||
reference_id: str | None = None,
|
||||
description: str | None = None,
|
||||
) -> tuple[DeductionResult, T]:
|
||||
"""
|
||||
执行带扣款的操作(函数式接口)
|
||||
|
||||
预扣款后执行指定操作,根据操作结果自动确认或取消。
|
||||
|
||||
用法示例:
|
||||
```python
|
||||
async def call_api(pre_auth: PreAuthResult):
|
||||
return await external_api.call(
|
||||
transaction_id=pre_auth.transaction_id,
|
||||
amount=pre_auth.amount_display,
|
||||
)
|
||||
|
||||
deduction_result, api_result = await balance_service.execute_with_deduction(
|
||||
user_id,
|
||||
1000,
|
||||
call_api,
|
||||
reference_type="api_call",
|
||||
)
|
||||
```
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
amount_units: 扣款金额(单位额度)
|
||||
operation: 要执行的异步操作,接收 PreAuthResult 参数
|
||||
reference_type: 关联业务类型
|
||||
reference_id: 关联业务 ID
|
||||
description: 交易描述
|
||||
|
||||
Returns:
|
||||
(DeductionResult, operation返回值): 扣款结果和操作结果
|
||||
|
||||
Raises:
|
||||
InsufficientBalanceError: 余额不足
|
||||
Exception: 操作抛出的异常(预授权会自动取消)
|
||||
"""
|
||||
pre_auth = await self.pre_authorize(
|
||||
user_id,
|
||||
amount_units,
|
||||
reference_type=reference_type,
|
||||
reference_id=reference_id,
|
||||
description=description,
|
||||
)
|
||||
|
||||
try:
|
||||
# 执行操作
|
||||
result = await operation(pre_auth)
|
||||
# 成功:确认扣款
|
||||
deduction_result = await self.confirm(pre_auth.transaction_id)
|
||||
return deduction_result, result
|
||||
except Exception as e:
|
||||
# 失败:取消预授权
|
||||
try:
|
||||
await self.cancel(
|
||||
pre_auth.transaction_id,
|
||||
reason=f"操作失败: {str(e)[:200]}",
|
||||
)
|
||||
except Exception as cancel_error:
|
||||
logger.error(
|
||||
f"取消预授权失败: {pre_auth.transaction_id}, "
|
||||
f"错误: {cancel_error}"
|
||||
)
|
||||
raise
|
||||
|
||||
# ============================================================
|
||||
# 兼容方法(保留旧接口)
|
||||
# ============================================================
|
||||
|
||||
async def deduct_with_freeze(
|
||||
self,
|
||||
user_id: str,
|
||||
amount_units: int,
|
||||
*,
|
||||
reference_type: str | None = None,
|
||||
reference_id: str | None = None,
|
||||
description: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
冻结并预扣款(兼容方法,推荐使用 pre_authorize)
|
||||
|
||||
Returns:
|
||||
交易ID
|
||||
"""
|
||||
result = await self.pre_authorize(
|
||||
user_id,
|
||||
amount_units,
|
||||
reference_type=reference_type,
|
||||
reference_id=reference_id,
|
||||
description=description,
|
||||
)
|
||||
return result.transaction_id
|
||||
|
||||
async def confirm_frozen_deduction(self, transaction_id: str) -> BalanceTransaction:
|
||||
"""
|
||||
确认冻结扣款(兼容方法,推荐使用 confirm)
|
||||
"""
|
||||
await self.confirm(transaction_id)
|
||||
transaction = await self.transaction_repo.get_by_id(transaction_id)
|
||||
return transaction # type: ignore
|
||||
|
||||
async def cancel_frozen_deduction(self, transaction_id: str) -> BalanceTransaction:
|
||||
"""
|
||||
取消冻结扣款(兼容方法,推荐使用 cancel)
|
||||
"""
|
||||
await self.cancel(transaction_id)
|
||||
transaction = await self.transaction_repo.get_by_id(transaction_id)
|
||||
return transaction # type: ignore
|
||||
|
||||
# ============================================================
|
||||
# 充值操作(使用乐观锁 + 重试)
|
||||
# ============================================================
|
||||
|
||||
async def recharge(
|
||||
self,
|
||||
user_id: str,
|
||||
amount_units: int,
|
||||
*,
|
||||
reference_type: str | None = None,
|
||||
reference_id: str | None = None,
|
||||
description: str | None = None,
|
||||
idempotency_key: str | None = None,
|
||||
) -> BalanceTransaction:
|
||||
"""
|
||||
充值
|
||||
|
||||
使用乐观锁,配合重试机制处理并发冲突。
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
amount_units: 充值金额(单位额度,正数)
|
||||
reference_type: 关联业务类型
|
||||
reference_id: 关联业务 ID
|
||||
description: 交易描述
|
||||
idempotency_key: 幂等键
|
||||
|
||||
Returns:
|
||||
交易记录
|
||||
|
||||
Raises:
|
||||
DuplicateTransactionError: 重复交易
|
||||
ConcurrencyError: 并发冲突(重试失败)
|
||||
"""
|
||||
if amount_units <= 0:
|
||||
raise ValidationError("充值金额必须大于 0")
|
||||
|
||||
# 检查幂等性
|
||||
if idempotency_key:
|
||||
existing = await self.transaction_repo.get_by_idempotency_key(
|
||||
idempotency_key
|
||||
)
|
||||
if existing:
|
||||
raise DuplicateTransactionError(idempotency_key)
|
||||
|
||||
# 乐观锁重试
|
||||
for attempt in range(self.MAX_RETRIES):
|
||||
balance = await self.balance_repo.get_or_create(user_id)
|
||||
balance_before = balance.balance
|
||||
|
||||
# 尝试更新余额
|
||||
success = await self.balance_repo.update_balance_optimistic(
|
||||
balance,
|
||||
amount_units,
|
||||
is_recharge=True,
|
||||
)
|
||||
|
||||
if success:
|
||||
# 创建交易记录
|
||||
transaction = await self.transaction_repo.create(
|
||||
user_id=user_id,
|
||||
balance_account_id=balance.id,
|
||||
transaction_type=TransactionType.RECHARGE,
|
||||
status=TransactionStatus.COMPLETED,
|
||||
amount=amount_units, # 正数表示收入
|
||||
balance_before=balance_before,
|
||||
balance_after=balance.balance,
|
||||
reference_type=reference_type,
|
||||
reference_id=reference_id,
|
||||
description=description,
|
||||
idempotency_key=idempotency_key,
|
||||
)
|
||||
|
||||
await self.balance_repo.commit()
|
||||
|
||||
logger.info(
|
||||
f"用户 {user_id} 充值成功: {amount_units} 单位, "
|
||||
f"余额 {balance_before} -> {balance.balance}"
|
||||
)
|
||||
|
||||
return transaction
|
||||
|
||||
# 冲突,重试
|
||||
logger.warning(
|
||||
f"用户 {user_id} 充值冲突,重试 {attempt + 1}/{self.MAX_RETRIES}"
|
||||
)
|
||||
await self.balance_repo.rollback()
|
||||
|
||||
# 重试失败
|
||||
raise ConcurrencyError()
|
||||
|
||||
# ============================================================
|
||||
# 管理员操作
|
||||
# ============================================================
|
||||
|
||||
async def admin_adjust(
|
||||
self,
|
||||
user_id: str,
|
||||
amount_units: int,
|
||||
*,
|
||||
operator_id: str,
|
||||
reason: str,
|
||||
) -> BalanceTransaction:
|
||||
"""
|
||||
管理员调整余额
|
||||
|
||||
Args:
|
||||
user_id: 目标用户 ID
|
||||
amount_units: 调整金额(正数增加,负数减少)
|
||||
operator_id: 操作人 ID
|
||||
reason: 调整原因
|
||||
|
||||
Returns:
|
||||
交易记录
|
||||
|
||||
Raises:
|
||||
InsufficientBalanceError: 减少金额时余额不足
|
||||
"""
|
||||
if amount_units == 0:
|
||||
raise ValidationError("调整金额不能为 0")
|
||||
|
||||
# 获取余额账户并加锁
|
||||
balance = await self.balance_repo.get_or_create_for_update(user_id)
|
||||
|
||||
# 减少时检查余额
|
||||
if amount_units < 0 and balance.available_balance < abs(amount_units):
|
||||
raise InsufficientBalanceError(
|
||||
abs(amount_units), balance.available_balance
|
||||
)
|
||||
|
||||
balance_before = balance.balance
|
||||
|
||||
# 执行调整
|
||||
balance.balance += amount_units
|
||||
if amount_units > 0:
|
||||
balance.total_recharged += amount_units
|
||||
balance.version += 1
|
||||
|
||||
# 创建交易记录
|
||||
transaction = await self.transaction_repo.create(
|
||||
user_id=user_id,
|
||||
balance_account_id=balance.id,
|
||||
transaction_type=TransactionType.ADJUSTMENT,
|
||||
status=TransactionStatus.COMPLETED,
|
||||
amount=amount_units,
|
||||
balance_before=balance_before,
|
||||
balance_after=balance.balance,
|
||||
description=reason,
|
||||
operator_id=operator_id,
|
||||
remark=f"管理员调整: {reason}",
|
||||
)
|
||||
|
||||
await self.balance_repo.commit()
|
||||
|
||||
logger.info(
|
||||
f"管理员 {operator_id} 调整用户 {user_id} 余额: {amount_units} 单位, "
|
||||
f"原因: {reason}"
|
||||
)
|
||||
|
||||
return transaction
|
||||
|
||||
395
app/services/oauth2.py
Normal file
395
app/services/oauth2.py
Normal file
@@ -0,0 +1,395 @@
|
||||
"""
|
||||
OAuth2 认证服务
|
||||
|
||||
处理 OAuth2 认证流程,支持主备端点自动切换。
|
||||
当首选端点不可达时,自动回退到备用端点。
|
||||
"""
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import httpx
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.exceptions import (
|
||||
AuthenticationError,
|
||||
ResourceConflictError,
|
||||
)
|
||||
from app.core.security import create_access_token, create_refresh_token
|
||||
from app.models.user import User
|
||||
from app.repositories.user import UserRepository
|
||||
from app.schemas.auth import TokenResponse
|
||||
from app.schemas.oauth2 import OAuth2TokenData, OAuth2UserInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# OAuth2 提供商标识
|
||||
OAUTH_PROVIDER_LINUXDO = "linuxdo"
|
||||
|
||||
|
||||
class OAuth2EndpointError(AuthenticationError):
|
||||
"""OAuth2 端点错误"""
|
||||
|
||||
def __init__(self, message: str = "OAuth2 服务不可用"):
|
||||
super().__init__(message, "OAUTH2_ENDPOINT_ERROR")
|
||||
|
||||
|
||||
class OAuth2StateError(AuthenticationError):
|
||||
"""OAuth2 状态验证错误"""
|
||||
|
||||
def __init__(self, message: str = "无效的状态码"):
|
||||
super().__init__(message, "OAUTH2_STATE_ERROR")
|
||||
|
||||
|
||||
class OAuth2Service:
|
||||
"""
|
||||
OAuth2 认证服务
|
||||
|
||||
特性:
|
||||
- 支持主备端点自动切换
|
||||
- 首选端点请求失败时自动回退到备用端点
|
||||
- 状态码验证防止 CSRF 攻击
|
||||
"""
|
||||
|
||||
# 存储状态码(生产环境应使用 Redis)
|
||||
_state_store: dict[str, bool] = {}
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
"""
|
||||
初始化 OAuth2 服务
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
"""
|
||||
self.session = session
|
||||
self.user_repo = UserRepository(session)
|
||||
|
||||
# 端点配置
|
||||
self._endpoints = {
|
||||
"authorize": {
|
||||
"primary": settings.oauth2_authorize_endpoint,
|
||||
"reserve": settings.oauth2_authorize_endpoint_reserve,
|
||||
},
|
||||
"token": {
|
||||
"primary": settings.oauth2_token_endpoint,
|
||||
"reserve": settings.oauth2_token_endpoint_reserve,
|
||||
},
|
||||
"userinfo": {
|
||||
"primary": settings.oauth2_user_info_endpoint,
|
||||
"reserve": settings.oauth2_user_info_endpoint_reserve,
|
||||
},
|
||||
}
|
||||
|
||||
self._timeout = settings.oauth2_request_timeout
|
||||
|
||||
def generate_authorize_url(self, redirect_uri: str) -> tuple[str, str]:
|
||||
"""
|
||||
生成 OAuth2 授权 URL
|
||||
|
||||
Args:
|
||||
redirect_uri: 回调 URL
|
||||
|
||||
Returns:
|
||||
(授权 URL, 状态码)
|
||||
"""
|
||||
state = secrets.token_urlsafe(32)
|
||||
self._state_store[state] = True # 存储状态码
|
||||
|
||||
params = {
|
||||
"client_id": settings.oauth2_client_id,
|
||||
"redirect_uri": redirect_uri,
|
||||
"response_type": "code",
|
||||
"state": state,
|
||||
"scope": "read", # 根据实际需要调整
|
||||
}
|
||||
|
||||
# 使用首选授权端点
|
||||
authorize_url = f"{self._endpoints['authorize']['primary']}?{urlencode(params)}"
|
||||
|
||||
return authorize_url, state
|
||||
|
||||
def validate_state(self, state: str) -> bool:
|
||||
"""
|
||||
验证状态码(防 CSRF)
|
||||
|
||||
Args:
|
||||
state: 状态码
|
||||
|
||||
Returns:
|
||||
是否有效
|
||||
"""
|
||||
if state in self._state_store:
|
||||
del self._state_store[state] # 使用后立即删除
|
||||
return True
|
||||
return False
|
||||
|
||||
async def _request_with_fallback(
|
||||
self,
|
||||
endpoint_type: str,
|
||||
method: str,
|
||||
**kwargs,
|
||||
) -> httpx.Response:
|
||||
"""
|
||||
带回退的 HTTP 请求
|
||||
|
||||
首先尝试首选端点,失败后自动切换到备用端点。
|
||||
|
||||
Args:
|
||||
endpoint_type: 端点类型(token/userinfo)
|
||||
method: HTTP 方法
|
||||
**kwargs: 请求参数
|
||||
|
||||
Returns:
|
||||
响应对象
|
||||
|
||||
Raises:
|
||||
OAuth2EndpointError: 所有端点都不可用
|
||||
"""
|
||||
endpoints = self._endpoints[endpoint_type]
|
||||
last_error: Exception | None = None
|
||||
|
||||
for endpoint_name, url in [("primary", endpoints["primary"]), ("reserve", endpoints["reserve"])]:
|
||||
try:
|
||||
logger.debug(f"尝试 OAuth2 {endpoint_type} 端点 ({endpoint_name}): {url}")
|
||||
|
||||
async with httpx.AsyncClient(timeout=self._timeout) as client:
|
||||
if method.upper() == "POST":
|
||||
response = await client.post(url, **kwargs)
|
||||
else:
|
||||
response = await client.get(url, **kwargs)
|
||||
|
||||
# 检查 HTTP 状态
|
||||
if response.status_code >= 500:
|
||||
logger.warning(
|
||||
f"OAuth2 {endpoint_type} 端点 ({endpoint_name}) "
|
||||
f"返回服务器错误: {response.status_code}"
|
||||
)
|
||||
continue
|
||||
|
||||
logger.info(f"OAuth2 {endpoint_type} 请求成功 ({endpoint_name})")
|
||||
return response
|
||||
|
||||
except httpx.TimeoutException as e:
|
||||
logger.warning(f"OAuth2 {endpoint_type} 端点 ({endpoint_name}) 超时: {e}")
|
||||
last_error = e
|
||||
except httpx.ConnectError as e:
|
||||
logger.warning(f"OAuth2 {endpoint_type} 端点 ({endpoint_name}) 连接失败: {e}")
|
||||
last_error = e
|
||||
except Exception as e:
|
||||
logger.error(f"OAuth2 {endpoint_type} 端点 ({endpoint_name}) 请求异常: {e}")
|
||||
last_error = e
|
||||
|
||||
# 所有端点都失败
|
||||
error_msg = f"OAuth2 {endpoint_type} 服务不可用"
|
||||
if last_error:
|
||||
error_msg += f": {last_error}"
|
||||
raise OAuth2EndpointError(error_msg)
|
||||
|
||||
async def exchange_code_for_token(
|
||||
self,
|
||||
code: str,
|
||||
redirect_uri: str,
|
||||
) -> OAuth2TokenData:
|
||||
"""
|
||||
用授权码换取访问令牌
|
||||
|
||||
Args:
|
||||
code: 授权码
|
||||
redirect_uri: 回调 URL(必须与授权时一致)
|
||||
|
||||
Returns:
|
||||
OAuth2 令牌数据
|
||||
|
||||
Raises:
|
||||
OAuth2EndpointError: 端点不可用
|
||||
AuthenticationError: 换取令牌失败
|
||||
"""
|
||||
data = {
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": settings.oauth2_client_id,
|
||||
"client_secret": settings.oauth2_client_secret,
|
||||
"code": code,
|
||||
"redirect_uri": redirect_uri,
|
||||
}
|
||||
|
||||
response = await self._request_with_fallback(
|
||||
"token",
|
||||
"POST",
|
||||
data=data,
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"OAuth2 token 响应错误: {response.status_code} - {response.text}")
|
||||
raise AuthenticationError(
|
||||
f"获取访问令牌失败: {response.status_code}",
|
||||
"OAUTH2_TOKEN_ERROR",
|
||||
)
|
||||
|
||||
token_data = response.json()
|
||||
return OAuth2TokenData(**token_data)
|
||||
|
||||
async def get_user_info(self, access_token: str) -> OAuth2UserInfo:
|
||||
"""
|
||||
获取 OAuth2 用户信息
|
||||
|
||||
Args:
|
||||
access_token: OAuth2 访问令牌
|
||||
|
||||
Returns:
|
||||
用户信息
|
||||
|
||||
Raises:
|
||||
OAuth2EndpointError: 端点不可用
|
||||
AuthenticationError: 获取用户信息失败
|
||||
"""
|
||||
response = await self._request_with_fallback(
|
||||
"userinfo",
|
||||
"GET",
|
||||
headers={
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Accept": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"OAuth2 userinfo 响应错误: {response.status_code} - {response.text}")
|
||||
raise AuthenticationError(
|
||||
f"获取用户信息失败: {response.status_code}",
|
||||
"OAUTH2_USERINFO_ERROR",
|
||||
)
|
||||
|
||||
user_data = response.json()
|
||||
return OAuth2UserInfo(**user_data)
|
||||
|
||||
async def authenticate(
|
||||
self,
|
||||
code: str,
|
||||
state: str,
|
||||
redirect_uri: str,
|
||||
) -> tuple[User, TokenResponse, bool]:
|
||||
"""
|
||||
完整的 OAuth2 认证流程
|
||||
|
||||
1. 验证状态码
|
||||
2. 用授权码换取令牌
|
||||
3. 获取用户信息
|
||||
4. 创建或更新用户
|
||||
5. 生成 JWT 令牌
|
||||
|
||||
Args:
|
||||
code: 授权码
|
||||
state: 状态码
|
||||
redirect_uri: 回调 URL
|
||||
|
||||
Returns:
|
||||
(用户对象, JWT 令牌响应, 是否新用户)
|
||||
|
||||
Raises:
|
||||
OAuth2StateError: 状态码无效
|
||||
OAuth2EndpointError: OAuth2 服务不可用
|
||||
AuthenticationError: 认证失败
|
||||
"""
|
||||
# 1. 验证状态码
|
||||
if not self.validate_state(state):
|
||||
raise OAuth2StateError()
|
||||
|
||||
# 2. 换取令牌
|
||||
oauth_token = await self.exchange_code_for_token(code, redirect_uri)
|
||||
|
||||
# 3. 获取用户信息
|
||||
oauth_user = await self.get_user_info(oauth_token.access_token)
|
||||
|
||||
# 4. 查找或创建用户
|
||||
user, is_new_user = await self._get_or_create_user(oauth_user)
|
||||
|
||||
# 5. 生成 JWT 令牌
|
||||
tokens = self._create_tokens(user)
|
||||
|
||||
return user, tokens, is_new_user
|
||||
|
||||
async def _get_or_create_user(
|
||||
self,
|
||||
oauth_user: OAuth2UserInfo,
|
||||
) -> tuple[User, bool]:
|
||||
"""
|
||||
根据 OAuth2 用户信息获取或创建本地用户
|
||||
|
||||
Args:
|
||||
oauth_user: OAuth2 用户信息
|
||||
|
||||
Returns:
|
||||
(用户对象, 是否新创建)
|
||||
"""
|
||||
oauth_user_id = str(oauth_user.id)
|
||||
|
||||
# 先通过 OAuth ID 查找
|
||||
user = await self.user_repo.get_by_oauth(
|
||||
provider=OAUTH_PROVIDER_LINUXDO,
|
||||
oauth_user_id=oauth_user_id,
|
||||
)
|
||||
|
||||
if user:
|
||||
# 更新用户信息(头像等可能变化)
|
||||
await self.user_repo.update(
|
||||
user,
|
||||
nickname=oauth_user.name or oauth_user.username,
|
||||
avatar_url=oauth_user.avatar_url,
|
||||
)
|
||||
await self.user_repo.commit()
|
||||
return user, False
|
||||
|
||||
# 检查用户名是否已存在
|
||||
username = oauth_user.username.lower()
|
||||
existing_user = await self.user_repo.get_by_username(username)
|
||||
if existing_user:
|
||||
# 用户名冲突,添加后缀
|
||||
username = f"{username}_{oauth_user_id[:8]}"
|
||||
|
||||
# 创建新用户
|
||||
user = await self.user_repo.create(
|
||||
username=username,
|
||||
email=oauth_user.email,
|
||||
nickname=oauth_user.name or oauth_user.username,
|
||||
avatar_url=oauth_user.avatar_url,
|
||||
oauth_provider=OAUTH_PROVIDER_LINUXDO,
|
||||
oauth_user_id=oauth_user_id,
|
||||
hashed_password=None, # OAuth2 用户无密码
|
||||
is_active=oauth_user.active,
|
||||
)
|
||||
|
||||
await self.user_repo.commit()
|
||||
logger.info(f"创建 OAuth2 用户: {user.username} (provider={OAUTH_PROVIDER_LINUXDO})")
|
||||
|
||||
return user, True
|
||||
|
||||
def _create_tokens(self, user: User) -> TokenResponse:
|
||||
"""
|
||||
为用户创建 JWT 令牌
|
||||
|
||||
Args:
|
||||
user: 用户对象
|
||||
|
||||
Returns:
|
||||
令牌响应
|
||||
"""
|
||||
access_token = create_access_token(
|
||||
subject=user.id,
|
||||
extra_claims={
|
||||
"username": user.username,
|
||||
"is_superuser": user.is_superuser,
|
||||
"oauth_provider": user.oauth_provider,
|
||||
},
|
||||
)
|
||||
|
||||
refresh_token = create_refresh_token(subject=user.id)
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
token_type="Bearer",
|
||||
expires_in=settings.access_token_expire_minutes * 60,
|
||||
)
|
||||
|
||||
570
app/services/redeem_code.py
Normal file
570
app/services/redeem_code.py
Normal file
@@ -0,0 +1,570 @@
|
||||
"""
|
||||
兑换码服务
|
||||
|
||||
处理兑换码相关的业务逻辑。
|
||||
|
||||
设计说明:
|
||||
- 兑换操作使用行级锁确保原子性
|
||||
- 支持批量生成和导入导出
|
||||
- 记录完整的使用日志
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.exceptions import (
|
||||
AppException,
|
||||
ResourceNotFoundError,
|
||||
ValidationError,
|
||||
)
|
||||
from app.models.redeem_code import (
|
||||
RedeemCode,
|
||||
RedeemCodeBatch,
|
||||
RedeemCodeUsageLog,
|
||||
RedeemCodeStatus,
|
||||
generate_redeem_code,
|
||||
)
|
||||
from app.models.balance import TransactionType
|
||||
from app.repositories.redeem_code import (
|
||||
RedeemCodeRepository,
|
||||
RedeemCodeBatchRepository,
|
||||
RedeemCodeUsageLogRepository,
|
||||
)
|
||||
from app.services.balance import BalanceService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RedeemCodeNotFoundError(AppException):
|
||||
"""兑换码不存在"""
|
||||
|
||||
def __init__(self, code: str):
|
||||
super().__init__(
|
||||
"兑换码不存在",
|
||||
"REDEEM_CODE_NOT_FOUND",
|
||||
{"code": code},
|
||||
)
|
||||
|
||||
|
||||
class RedeemCodeInvalidError(AppException):
|
||||
"""兑换码无效"""
|
||||
|
||||
def __init__(self, code: str, reason: str):
|
||||
super().__init__(
|
||||
f"兑换码无效: {reason}",
|
||||
"REDEEM_CODE_INVALID",
|
||||
{"code": code, "reason": reason},
|
||||
)
|
||||
|
||||
|
||||
class RedeemCodeExpiredError(AppException):
|
||||
"""兑换码已过期"""
|
||||
|
||||
def __init__(self, code: str):
|
||||
super().__init__(
|
||||
"兑换码已过期",
|
||||
"REDEEM_CODE_EXPIRED",
|
||||
{"code": code},
|
||||
)
|
||||
|
||||
|
||||
class RedeemCodeUsedError(AppException):
|
||||
"""兑换码已使用"""
|
||||
|
||||
def __init__(self, code: str):
|
||||
super().__init__(
|
||||
"兑换码已使用",
|
||||
"REDEEM_CODE_USED",
|
||||
{"code": code},
|
||||
)
|
||||
|
||||
|
||||
class RedeemCodeDisabledError(AppException):
|
||||
"""兑换码已禁用"""
|
||||
|
||||
def __init__(self, code: str):
|
||||
super().__init__(
|
||||
"兑换码已禁用",
|
||||
"REDEEM_CODE_DISABLED",
|
||||
{"code": code},
|
||||
)
|
||||
|
||||
|
||||
class RedeemCodeService:
|
||||
"""兑换码服务"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
"""
|
||||
初始化兑换码服务
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
"""
|
||||
self.session = session
|
||||
self.code_repo = RedeemCodeRepository(session)
|
||||
self.batch_repo = RedeemCodeBatchRepository(session)
|
||||
self.log_repo = RedeemCodeUsageLogRepository(session)
|
||||
self.balance_service = BalanceService(session)
|
||||
|
||||
# ============================================================
|
||||
# 用户兑换
|
||||
# ============================================================
|
||||
|
||||
async def redeem(
|
||||
self,
|
||||
user_id: str,
|
||||
code: str,
|
||||
*,
|
||||
ip_address: str | None = None,
|
||||
user_agent: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
用户兑换余额
|
||||
|
||||
使用行级锁确保原子性,防止并发兑换。
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
code: 兑换码
|
||||
ip_address: 客户端 IP
|
||||
user_agent: User Agent
|
||||
|
||||
Returns:
|
||||
兑换结果
|
||||
|
||||
Raises:
|
||||
RedeemCodeNotFoundError: 兑换码不存在
|
||||
RedeemCodeInvalidError: 兑换码无效
|
||||
"""
|
||||
# 标准化兑换码
|
||||
normalized_code = code.strip().upper().replace(" ", "")
|
||||
|
||||
# 获取兑换码并加锁
|
||||
redeem_code = await self.code_repo.get_by_code_for_update(normalized_code)
|
||||
|
||||
if not redeem_code:
|
||||
raise RedeemCodeNotFoundError(normalized_code)
|
||||
|
||||
# 验证兑换码状态
|
||||
self._validate_redeem_code(redeem_code)
|
||||
|
||||
# 获取用户当前余额
|
||||
balance = await self.balance_service.get_balance(user_id)
|
||||
balance_before = balance.balance
|
||||
|
||||
# 执行充值
|
||||
transaction = await self.balance_service.recharge(
|
||||
user_id,
|
||||
redeem_code.face_value,
|
||||
reference_type="redeem_code",
|
||||
reference_id=redeem_code.id,
|
||||
description=f"兑换码充值: {redeem_code.code}",
|
||||
)
|
||||
|
||||
# 标记兑换码已使用
|
||||
await self.code_repo.mark_as_used(redeem_code, user_id)
|
||||
|
||||
# 更新批次统计
|
||||
if redeem_code.batch_id:
|
||||
await self.batch_repo.increment_used_count(redeem_code.batch_id)
|
||||
|
||||
# 记录使用日志
|
||||
await self.log_repo.create(
|
||||
redeem_code_id=redeem_code.id,
|
||||
user_id=user_id,
|
||||
transaction_id=transaction.id,
|
||||
code_snapshot=redeem_code.code,
|
||||
face_value=redeem_code.face_value,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
|
||||
await self.code_repo.commit()
|
||||
|
||||
logger.info(
|
||||
f"用户 {user_id} 兑换成功: {redeem_code.code}, "
|
||||
f"面值 {redeem_code.face_value} 单位"
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "兑换成功",
|
||||
"face_value": f"{redeem_code.face_value / 1000:.2f}",
|
||||
"balance_before": f"{balance_before / 1000:.2f}",
|
||||
"balance_after": f"{transaction.balance_after / 1000:.2f}",
|
||||
}
|
||||
|
||||
def _validate_redeem_code(self, code: RedeemCode) -> None:
|
||||
"""验证兑换码有效性"""
|
||||
if code.status == RedeemCodeStatus.DISABLED:
|
||||
raise RedeemCodeDisabledError(code.code)
|
||||
|
||||
if code.status == RedeemCodeStatus.USED or code.used_count >= code.max_uses:
|
||||
raise RedeemCodeUsedError(code.code)
|
||||
|
||||
if code.expires_at and code.expires_at < datetime.now(timezone.utc):
|
||||
raise RedeemCodeExpiredError(code.code)
|
||||
|
||||
# ============================================================
|
||||
# 管理员:批量生成
|
||||
# ============================================================
|
||||
|
||||
async def create_batch(
|
||||
self,
|
||||
name: str,
|
||||
face_value_units: int,
|
||||
count: int,
|
||||
*,
|
||||
created_by: str,
|
||||
description: str | None = None,
|
||||
max_uses: int = 1,
|
||||
expires_at: datetime | None = None,
|
||||
) -> RedeemCodeBatch:
|
||||
"""
|
||||
创建兑换码批次
|
||||
|
||||
批量生成指定数量的兑换码。
|
||||
|
||||
Args:
|
||||
name: 批次名称
|
||||
face_value_units: 面值(单位额度)
|
||||
count: 生成数量
|
||||
created_by: 创建者 ID
|
||||
description: 批次描述
|
||||
max_uses: 每个兑换码最大使用次数
|
||||
expires_at: 过期时间
|
||||
|
||||
Returns:
|
||||
创建的批次
|
||||
"""
|
||||
if face_value_units <= 0:
|
||||
raise ValidationError("面值必须大于 0")
|
||||
if count <= 0 or count > 10000:
|
||||
raise ValidationError("数量必须在 1-10000 之间")
|
||||
|
||||
# 创建批次
|
||||
batch = await self.batch_repo.create(
|
||||
name=name,
|
||||
description=description,
|
||||
face_value=face_value_units,
|
||||
total_count=count,
|
||||
created_by=created_by,
|
||||
)
|
||||
|
||||
# 批量生成兑换码
|
||||
codes_data = []
|
||||
generated_codes = set()
|
||||
|
||||
while len(codes_data) < count:
|
||||
new_code = generate_redeem_code()
|
||||
if new_code not in generated_codes:
|
||||
generated_codes.add(new_code)
|
||||
codes_data.append({
|
||||
"code": new_code,
|
||||
"batch_id": batch.id,
|
||||
"face_value": face_value_units,
|
||||
"max_uses": max_uses,
|
||||
"expires_at": expires_at,
|
||||
"created_by": created_by,
|
||||
})
|
||||
|
||||
await self.code_repo.bulk_create(codes_data)
|
||||
await self.batch_repo.commit()
|
||||
|
||||
logger.info(
|
||||
f"管理员 {created_by} 创建批次 '{name}': "
|
||||
f"{count} 个兑换码, 面值 {face_value_units} 单位"
|
||||
)
|
||||
|
||||
return batch
|
||||
|
||||
# ============================================================
|
||||
# 管理员:导入
|
||||
# ============================================================
|
||||
|
||||
async def import_codes(
|
||||
self,
|
||||
codes: list[dict[str, Any]],
|
||||
*,
|
||||
created_by: str,
|
||||
batch_name: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
导入兑换码
|
||||
|
||||
Args:
|
||||
codes: 兑换码数据列表
|
||||
created_by: 创建者 ID
|
||||
batch_name: 批次名称(可选)
|
||||
|
||||
Returns:
|
||||
导入结果
|
||||
"""
|
||||
batch_id = None
|
||||
|
||||
# 创建批次(如果指定)
|
||||
if batch_name:
|
||||
# 计算总面值用于批次记录
|
||||
total_face_value = sum(c.get("face_value_units", 0) for c in codes)
|
||||
batch = await self.batch_repo.create(
|
||||
name=batch_name,
|
||||
description="导入批次",
|
||||
face_value=total_face_value // len(codes) if codes else 0,
|
||||
total_count=len(codes),
|
||||
created_by=created_by,
|
||||
)
|
||||
batch_id = batch.id
|
||||
|
||||
success_count = 0
|
||||
failed_codes = []
|
||||
|
||||
for code_data in codes:
|
||||
try:
|
||||
# 检查兑换码是否已存在
|
||||
existing = await self.code_repo.get_by_code(code_data["code"])
|
||||
if existing:
|
||||
failed_codes.append(code_data["code"])
|
||||
continue
|
||||
|
||||
# 创建兑换码
|
||||
await self.code_repo.create(
|
||||
code=code_data["code"].strip().upper(),
|
||||
batch_id=batch_id,
|
||||
face_value=code_data["face_value_units"],
|
||||
max_uses=code_data.get("max_uses", 1),
|
||||
expires_at=code_data.get("expires_at"),
|
||||
remark=code_data.get("remark"),
|
||||
created_by=created_by,
|
||||
)
|
||||
success_count += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"导入兑换码失败: {code_data.get('code')}, {e}")
|
||||
failed_codes.append(code_data.get("code", "unknown"))
|
||||
|
||||
await self.code_repo.commit()
|
||||
|
||||
logger.info(
|
||||
f"管理员 {created_by} 导入兑换码: "
|
||||
f"成功 {success_count}, 失败 {len(failed_codes)}"
|
||||
)
|
||||
|
||||
return {
|
||||
"success_count": success_count,
|
||||
"failed_count": len(failed_codes),
|
||||
"failed_codes": failed_codes,
|
||||
"batch_id": batch_id,
|
||||
}
|
||||
|
||||
# ============================================================
|
||||
# 管理员:导出
|
||||
# ============================================================
|
||||
|
||||
async def export_codes(
|
||||
self,
|
||||
*,
|
||||
batch_id: str | None = None,
|
||||
status: RedeemCodeStatus | None = None,
|
||||
limit: int = 10000,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
导出兑换码
|
||||
|
||||
Args:
|
||||
batch_id: 批次 ID 过滤
|
||||
status: 状态过滤
|
||||
limit: 最大导出数量
|
||||
|
||||
Returns:
|
||||
兑换码数据列表
|
||||
"""
|
||||
codes = await self.code_repo.get_all_with_filters(
|
||||
batch_id=batch_id,
|
||||
status=status,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
result = []
|
||||
for code in codes:
|
||||
result.append({
|
||||
"code": code.code,
|
||||
"face_value": f"{code.face_value / 1000:.2f}",
|
||||
"status": code.status.value,
|
||||
"max_uses": code.max_uses,
|
||||
"used_count": code.used_count,
|
||||
"expires_at": code.expires_at.isoformat() if code.expires_at else None,
|
||||
"created_at": code.created_at.isoformat(),
|
||||
"used_at": code.used_at.isoformat() if code.used_at else None,
|
||||
"used_by": code.used_by,
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
# ============================================================
|
||||
# 管理员:查询
|
||||
# ============================================================
|
||||
|
||||
async def get_codes(
|
||||
self,
|
||||
*,
|
||||
offset: int = 0,
|
||||
limit: int = 20,
|
||||
status: RedeemCodeStatus | None = None,
|
||||
batch_id: str | None = None,
|
||||
code_like: str | None = None,
|
||||
created_after: datetime | None = None,
|
||||
created_before: datetime | None = None,
|
||||
) -> tuple[list[RedeemCode], int]:
|
||||
"""
|
||||
获取兑换码列表
|
||||
|
||||
Returns:
|
||||
(兑换码列表, 总数)
|
||||
"""
|
||||
codes = await self.code_repo.get_all_with_filters(
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
status=status,
|
||||
batch_id=batch_id,
|
||||
code_like=code_like,
|
||||
created_after=created_after,
|
||||
created_before=created_before,
|
||||
)
|
||||
total = await self.code_repo.count_with_filters(
|
||||
status=status,
|
||||
batch_id=batch_id,
|
||||
code_like=code_like,
|
||||
created_after=created_after,
|
||||
created_before=created_before,
|
||||
)
|
||||
return codes, total
|
||||
|
||||
async def get_code_detail(self, code_id: str) -> RedeemCode:
|
||||
"""
|
||||
获取兑换码详情
|
||||
|
||||
Args:
|
||||
code_id: 兑换码 ID
|
||||
|
||||
Returns:
|
||||
兑换码记录
|
||||
"""
|
||||
code = await self.code_repo.get_by_id(code_id)
|
||||
if not code:
|
||||
raise ResourceNotFoundError("兑换码不存在", "redeem_code", code_id)
|
||||
return code
|
||||
|
||||
async def disable_code(self, code_id: str) -> RedeemCode:
|
||||
"""
|
||||
禁用兑换码
|
||||
|
||||
Args:
|
||||
code_id: 兑换码 ID
|
||||
|
||||
Returns:
|
||||
更新后的兑换码
|
||||
"""
|
||||
code = await self.get_code_detail(code_id)
|
||||
code = await self.code_repo.disable_code(code)
|
||||
await self.code_repo.commit()
|
||||
|
||||
logger.info(f"兑换码已禁用: {code.code}")
|
||||
return code
|
||||
|
||||
async def enable_code(self, code_id: str) -> RedeemCode:
|
||||
"""
|
||||
启用兑换码
|
||||
|
||||
Args:
|
||||
code_id: 兑换码 ID
|
||||
|
||||
Returns:
|
||||
更新后的兑换码
|
||||
"""
|
||||
code = await self.get_code_detail(code_id)
|
||||
code = await self.code_repo.enable_code(code)
|
||||
await self.code_repo.commit()
|
||||
|
||||
logger.info(f"兑换码已启用: {code.code}")
|
||||
return code
|
||||
|
||||
# ============================================================
|
||||
# 管理员:批次管理
|
||||
# ============================================================
|
||||
|
||||
async def get_batches(
|
||||
self,
|
||||
*,
|
||||
offset: int = 0,
|
||||
limit: int = 20,
|
||||
) -> tuple[list[RedeemCodeBatch], int]:
|
||||
"""
|
||||
获取批次列表
|
||||
|
||||
Returns:
|
||||
(批次列表, 总数)
|
||||
"""
|
||||
batches = await self.batch_repo.get_all_batches(
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
)
|
||||
total = await self.batch_repo.count()
|
||||
return batches, total
|
||||
|
||||
async def get_batch_detail(self, batch_id: str) -> RedeemCodeBatch:
|
||||
"""
|
||||
获取批次详情
|
||||
|
||||
Args:
|
||||
batch_id: 批次 ID
|
||||
|
||||
Returns:
|
||||
批次记录
|
||||
"""
|
||||
batch = await self.batch_repo.get_by_id(batch_id)
|
||||
if not batch:
|
||||
raise ResourceNotFoundError("批次不存在", "batch", batch_id)
|
||||
return batch
|
||||
|
||||
# ============================================================
|
||||
# 管理员:使用日志
|
||||
# ============================================================
|
||||
|
||||
async def get_usage_logs(
|
||||
self,
|
||||
*,
|
||||
offset: int = 0,
|
||||
limit: int = 20,
|
||||
redeem_code_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
code_like: str | None = None,
|
||||
created_after: datetime | None = None,
|
||||
created_before: datetime | None = None,
|
||||
) -> tuple[list[RedeemCodeUsageLog], int]:
|
||||
"""
|
||||
获取使用日志
|
||||
|
||||
Returns:
|
||||
(日志列表, 总数)
|
||||
"""
|
||||
logs = await self.log_repo.get_all_with_filters(
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
redeem_code_id=redeem_code_id,
|
||||
user_id=user_id,
|
||||
code_like=code_like,
|
||||
created_after=created_after,
|
||||
created_before=created_before,
|
||||
)
|
||||
total = await self.log_repo.count_with_filters(
|
||||
redeem_code_id=redeem_code_id,
|
||||
user_id=user_id,
|
||||
code_like=code_like,
|
||||
created_after=created_after,
|
||||
created_before=created_before,
|
||||
)
|
||||
return logs, total
|
||||
|
||||
175
app/services/user.py
Normal file
175
app/services/user.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""
|
||||
用户服务
|
||||
|
||||
处理用户相关的业务逻辑。
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.exceptions import (
|
||||
UserAlreadyExistsError,
|
||||
UserNotFoundError,
|
||||
)
|
||||
from app.core.security import hash_password
|
||||
from app.models.user import User
|
||||
from app.repositories.user import UserRepository
|
||||
from app.schemas.user import UserCreate, UserUpdate
|
||||
|
||||
|
||||
class UserService:
|
||||
"""用户服务"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
"""
|
||||
初始化用户服务
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
"""
|
||||
self.session = session
|
||||
self.user_repo = UserRepository(session)
|
||||
|
||||
async def create_user(self, user_data: UserCreate) -> User:
|
||||
"""
|
||||
创建新用户
|
||||
|
||||
Args:
|
||||
user_data: 用户注册数据
|
||||
|
||||
Returns:
|
||||
新创建的用户
|
||||
|
||||
Raises:
|
||||
UserAlreadyExistsError: 用户名或邮箱已存在
|
||||
"""
|
||||
# 检查用户名是否已存在
|
||||
if await self.user_repo.exists_by_username(user_data.username):
|
||||
raise UserAlreadyExistsError("用户名")
|
||||
|
||||
# 检查邮箱是否已存在
|
||||
if user_data.email and await self.user_repo.exists_by_email(user_data.email):
|
||||
raise UserAlreadyExistsError("邮箱")
|
||||
|
||||
# 创建用户
|
||||
user = await self.user_repo.create(
|
||||
username=user_data.username.lower(),
|
||||
email=user_data.email.lower() if user_data.email else None,
|
||||
nickname=user_data.nickname,
|
||||
hashed_password=hash_password(user_data.password),
|
||||
)
|
||||
|
||||
await self.user_repo.commit()
|
||||
return user
|
||||
|
||||
async def get_user_by_id(self, user_id: str) -> User:
|
||||
"""
|
||||
通过 ID 获取用户
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
|
||||
Returns:
|
||||
用户对象
|
||||
|
||||
Raises:
|
||||
UserNotFoundError: 用户不存在
|
||||
"""
|
||||
user = await self.user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
raise UserNotFoundError(user_id)
|
||||
return user
|
||||
|
||||
async def get_user_by_username(self, username: str) -> User | None:
|
||||
"""
|
||||
通过用户名获取用户
|
||||
|
||||
Args:
|
||||
username: 用户名
|
||||
|
||||
Returns:
|
||||
用户对象或 None
|
||||
"""
|
||||
return await self.user_repo.get_by_username(username)
|
||||
|
||||
async def update_user(
|
||||
self,
|
||||
user_id: str,
|
||||
update_data: UserUpdate,
|
||||
) -> User:
|
||||
"""
|
||||
更新用户信息
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
update_data: 更新数据
|
||||
|
||||
Returns:
|
||||
更新后的用户
|
||||
|
||||
Raises:
|
||||
UserNotFoundError: 用户不存在
|
||||
UserAlreadyExistsError: 邮箱已被使用
|
||||
"""
|
||||
user = await self.get_user_by_id(user_id)
|
||||
|
||||
# 检查邮箱是否被其他用户使用
|
||||
if update_data.email:
|
||||
existing_user = await self.user_repo.get_by_email(update_data.email)
|
||||
if existing_user and existing_user.id != user_id:
|
||||
raise UserAlreadyExistsError("邮箱")
|
||||
|
||||
# 准备更新数据
|
||||
update_dict = update_data.model_dump(exclude_unset=True)
|
||||
if update_dict.get("email"):
|
||||
update_dict["email"] = update_dict["email"].lower()
|
||||
|
||||
# 更新用户
|
||||
user = await self.user_repo.update(user, **update_dict)
|
||||
await self.user_repo.commit()
|
||||
return user
|
||||
|
||||
async def update_last_login(self, user: User) -> None:
|
||||
"""
|
||||
更新用户最后登录时间
|
||||
|
||||
Args:
|
||||
user: 用户对象
|
||||
"""
|
||||
await self.user_repo.update(
|
||||
user,
|
||||
last_login_at=datetime.now(timezone.utc),
|
||||
)
|
||||
await self.user_repo.commit()
|
||||
|
||||
async def deactivate_user(self, user_id: str) -> User:
|
||||
"""
|
||||
禁用用户账户
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
|
||||
Returns:
|
||||
更新后的用户
|
||||
"""
|
||||
user = await self.get_user_by_id(user_id)
|
||||
user = await self.user_repo.update(user, is_active=False)
|
||||
await self.user_repo.commit()
|
||||
return user
|
||||
|
||||
async def activate_user(self, user_id: str) -> User:
|
||||
"""
|
||||
激活用户账户
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
|
||||
Returns:
|
||||
更新后的用户
|
||||
"""
|
||||
user = await self.get_user_by_id(user_id)
|
||||
user = await self.user_repo.update(user, is_active=True)
|
||||
await self.user_repo.commit()
|
||||
return user
|
||||
|
||||
Reference in New Issue
Block a user