提供基本前后端骨架

This commit is contained in:
hisatri
2026-01-06 23:49:23 +08:00
parent 84d4ccc226
commit 06f8176e23
89 changed files with 19293 additions and 2 deletions

378
app/repositories/balance.py Normal file
View File

@@ -0,0 +1,378 @@
"""
余额仓库
处理余额相关的数据库操作。
设计说明:
- 使用乐观锁version防止并发更新冲突
- 提供行级锁支持(悲观锁)用于关键操作
"""
from datetime import datetime
from typing import Any
from sqlalchemy import select, update, and_
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.balance import (
UserBalance,
BalanceTransaction,
TransactionType,
TransactionStatus,
)
from app.repositories.base import BaseRepository
class BalanceRepository(BaseRepository[UserBalance]):
"""余额仓库"""
model = UserBalance
async def get_by_user_id(self, user_id: str) -> UserBalance | None:
"""
通过用户 ID 获取余额账户
Args:
user_id: 用户 ID
Returns:
余额账户或 None
"""
stmt = select(UserBalance).where(UserBalance.user_id == user_id)
result = await self.session.execute(stmt)
return result.scalar_one_or_none()
async def get_by_user_id_for_update(self, user_id: str) -> UserBalance | None:
"""
通过用户 ID 获取余额账户(加行级锁)
用于需要原子性更新的场景,如扣款操作。
Args:
user_id: 用户 ID
Returns:
余额账户或 None
"""
stmt = (
select(UserBalance)
.where(UserBalance.user_id == user_id)
.with_for_update() # 行级锁
)
result = await self.session.execute(stmt)
return result.scalar_one_or_none()
async def get_or_create(self, user_id: str) -> UserBalance:
"""
获取或创建余额账户
如果用户没有余额账户,自动创建一个。
Args:
user_id: 用户 ID
Returns:
余额账户
"""
balance = await self.get_by_user_id(user_id)
if balance is None:
balance = await self.create(user_id=user_id)
return balance
async def get_or_create_for_update(self, user_id: str) -> UserBalance:
"""
获取或创建余额账户(加行级锁)
Args:
user_id: 用户 ID
Returns:
余额账户
"""
balance = await self.get_by_user_id_for_update(user_id)
if balance is None:
balance = await self.create(user_id=user_id)
# 重新获取并加锁
await self.session.flush()
balance = await self.get_by_user_id_for_update(user_id)
return balance # type: ignore
async def update_balance_optimistic(
self,
balance: UserBalance,
delta: int,
*,
is_recharge: bool = False,
is_consumption: bool = False,
) -> bool:
"""
使用乐观锁更新余额
通过版本号检查确保并发安全。
Args:
balance: 余额账户
delta: 变化量(正数增加,负数减少)
is_recharge: 是否为充值
is_consumption: 是否为消费
Returns:
是否更新成功
"""
current_version = balance.version
new_balance = balance.balance + delta
# 构建更新语句
update_values: dict[str, Any] = {
"balance": new_balance,
"version": current_version + 1,
}
if is_recharge and delta > 0:
update_values["total_recharged"] = balance.total_recharged + delta
if is_consumption and delta < 0:
update_values["total_consumed"] = balance.total_consumed + abs(delta)
stmt = (
update(UserBalance)
.where(
and_(
UserBalance.id == balance.id,
UserBalance.version == current_version, # 乐观锁检查
)
)
.values(**update_values)
)
result = await self.session.execute(stmt)
if result.rowcount == 1:
# 更新成功,刷新对象
balance.balance = new_balance
balance.version = current_version + 1
if is_recharge and delta > 0:
balance.total_recharged += delta
if is_consumption and delta < 0:
balance.total_consumed += abs(delta)
return True
return False
async def freeze_balance(
self,
balance: UserBalance,
amount: int,
) -> bool:
"""
冻结余额
Args:
balance: 余额账户
amount: 冻结金额(正数)
Returns:
是否成功
"""
if amount <= 0:
return False
if balance.available_balance < amount:
return False
current_version = balance.version
stmt = (
update(UserBalance)
.where(
and_(
UserBalance.id == balance.id,
UserBalance.version == current_version,
UserBalance.balance - UserBalance.frozen_balance >= amount,
)
)
.values(
frozen_balance=UserBalance.frozen_balance + amount,
version=current_version + 1,
)
)
result = await self.session.execute(stmt)
if result.rowcount == 1:
balance.frozen_balance += amount
balance.version = current_version + 1
return True
return False
async def unfreeze_balance(
self,
balance: UserBalance,
amount: int,
) -> bool:
"""
解冻余额
Args:
balance: 余额账户
amount: 解冻金额(正数)
Returns:
是否成功
"""
if amount <= 0:
return False
if balance.frozen_balance < amount:
return False
current_version = balance.version
stmt = (
update(UserBalance)
.where(
and_(
UserBalance.id == balance.id,
UserBalance.version == current_version,
UserBalance.frozen_balance >= amount,
)
)
.values(
frozen_balance=UserBalance.frozen_balance - amount,
version=current_version + 1,
)
)
result = await self.session.execute(stmt)
if result.rowcount == 1:
balance.frozen_balance -= amount
balance.version = current_version + 1
return True
return False
class TransactionRepository(BaseRepository[BalanceTransaction]):
"""交易记录仓库"""
model = BalanceTransaction
async def get_by_user_id(
self,
user_id: str,
*,
offset: int = 0,
limit: int = 20,
transaction_type: TransactionType | None = None,
status: TransactionStatus | None = None,
) -> list[BalanceTransaction]:
"""
获取用户的交易记录
Args:
user_id: 用户 ID
offset: 偏移量
limit: 限制数量
transaction_type: 交易类型过滤
status: 状态过滤
Returns:
交易记录列表
"""
stmt = select(BalanceTransaction).where(
BalanceTransaction.user_id == user_id
)
if transaction_type:
stmt = stmt.where(BalanceTransaction.transaction_type == transaction_type)
if status:
stmt = stmt.where(BalanceTransaction.status == status)
stmt = (
stmt
.order_by(BalanceTransaction.created_at.desc())
.offset(offset)
.limit(limit)
)
result = await self.session.execute(stmt)
return list(result.scalars().all())
async def count_by_user_id(
self,
user_id: str,
*,
transaction_type: TransactionType | None = None,
status: TransactionStatus | None = None,
) -> int:
"""
统计用户的交易记录数量
Args:
user_id: 用户 ID
transaction_type: 交易类型过滤
status: 状态过滤
Returns:
记录数量
"""
from sqlalchemy import func
stmt = select(func.count()).select_from(BalanceTransaction).where(
BalanceTransaction.user_id == user_id
)
if transaction_type:
stmt = stmt.where(BalanceTransaction.transaction_type == transaction_type)
if status:
stmt = stmt.where(BalanceTransaction.status == status)
result = await self.session.execute(stmt)
return result.scalar() or 0
async def get_by_idempotency_key(
self,
idempotency_key: str,
) -> BalanceTransaction | None:
"""
通过幂等键获取交易记录
用于防止重复提交。
Args:
idempotency_key: 幂等键
Returns:
交易记录或 None
"""
stmt = select(BalanceTransaction).where(
BalanceTransaction.idempotency_key == idempotency_key
)
result = await self.session.execute(stmt)
return result.scalar_one_or_none()
async def get_by_reference(
self,
reference_type: str,
reference_id: str,
) -> list[BalanceTransaction]:
"""
通过业务关联获取交易记录
Args:
reference_type: 关联业务类型
reference_id: 关联业务 ID
Returns:
交易记录列表
"""
stmt = (
select(BalanceTransaction)
.where(
and_(
BalanceTransaction.reference_type == reference_type,
BalanceTransaction.reference_id == reference_id,
)
)
.order_by(BalanceTransaction.created_at.desc())
)
result = await self.session.execute(stmt)
return list(result.scalars().all())