Files
SatoNano/app/repositories/balance.py
2026-01-06 23:49:23 +08:00

379 lines
10 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
余额仓库
处理余额相关的数据库操作。
设计说明:
- 使用乐观锁version防止并发更新冲突
- 提供行级锁支持(悲观锁)用于关键操作
"""
from datetime import datetime
from typing import Any

from sqlalchemy import and_, func, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import set_committed_value

from app.models.balance import (
    UserBalance,
    BalanceTransaction,
    TransactionType,
    TransactionStatus,
)
from app.repositories.base import BaseRepository
class BalanceRepository(BaseRepository[UserBalance]):
    """Repository for user balance accounts (``UserBalance``).

    Concurrency strategy:

    - Balance changes are conditional UPDATEs guarded by the ``version``
      column (optimistic locking). A ``False`` return means nothing was
      written: the row changed concurrently or a guard condition failed.
    - The ``*_for_update`` getters take a row-level lock
      (``SELECT ... FOR UPDATE``) for critical read-modify-write sequences.
    """

    model = UserBalance

    async def get_by_user_id(self, user_id: str) -> UserBalance | None:
        """Fetch a user's balance account.

        Args:
            user_id: User ID.

        Returns:
            The balance account, or ``None`` if the user has none.
        """
        stmt = select(UserBalance).where(UserBalance.user_id == user_id)
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_by_user_id_for_update(self, user_id: str) -> UserBalance | None:
        """Fetch a user's balance account and take a row-level lock on it.

        Intended for operations that must update atomically, such as
        deductions.

        ``populate_existing`` makes the freshly locked row overwrite an
        instance that is already in the session's identity map. Without it,
        the lock would be acquired but the caller would keep working with the
        (possibly stale) values loaded earlier.

        Args:
            user_id: User ID.

        Returns:
            The locked balance account, or ``None`` if it does not exist.
        """
        stmt = (
            select(UserBalance)
            .where(UserBalance.user_id == user_id)
            .with_for_update()  # row-level lock
            .execution_options(populate_existing=True)
        )
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_or_create(self, user_id: str) -> UserBalance:
        """Fetch a user's balance account, creating one if it does not exist.

        NOTE(review): two concurrent first-time callers can both see ``None``
        and both call ``create``. Whether the second insert fails depends on
        the constraints on ``user_id`` — confirm.

        Args:
            user_id: User ID.

        Returns:
            The existing or newly created balance account.
        """
        balance = await self.get_by_user_id(user_id)
        if balance is None:
            balance = await self.create(user_id=user_id)
        return balance

    async def get_or_create_for_update(self, user_id: str) -> UserBalance:
        """Fetch (or create) a user's balance account under a row-level lock.

        Args:
            user_id: User ID.

        Returns:
            The locked balance account.
        """
        balance = await self.get_by_user_id_for_update(user_id)
        if balance is None:
            balance = await self.create(user_id=user_id)
            # Flush so the new row is visible to the SELECT below, then
            # re-read it with FOR UPDATE.
            await self.session.flush()
            balance = await self.get_by_user_id_for_update(user_id)
        return balance  # type: ignore

    @staticmethod
    def _sync_committed(account: UserBalance, values: dict[str, Any]) -> None:
        """Mirror values that were already written to the database onto ``account``.

        ``set_committed_value`` updates the loaded attributes without
        recording change history. The next flush therefore does not re-issue
        an UPDATE for values the database already holds.
        """
        for key, value in values.items():
            set_committed_value(account, key, value)

    async def _versioned_update(
        self,
        account: UserBalance,
        current_version: int,
        values: dict[str, Any],
        *conditions: Any,
    ) -> bool:
        """Run a version-guarded UPDATE on ``account``'s row.

        The row is matched by id, by ``version == current_version`` (the
        optimistic-lock check), and by any extra ``conditions``. On a match,
        ``values`` are written and ``version`` is bumped to
        ``current_version + 1``.

        ``synchronize_session=False`` stops the ORM from also patching the
        loaded instance. Callers mirror the change themselves via
        ``_sync_committed``. With the ORM's default synchronization, the SET
        values were already applied to the instance before the manual update
        ran, so the change was counted twice in memory. Under that default,
        a version conflict could also patch the instance even though no row
        was updated.

        Returns:
            ``True`` if exactly one row was updated.
        """
        stmt = (
            update(UserBalance)
            .where(
                and_(
                    UserBalance.id == account.id,
                    UserBalance.version == current_version,  # optimistic-lock check
                    *conditions,
                )
            )
            .values(**values, version=current_version + 1)
            .execution_options(synchronize_session=False)
        )
        result = await self.session.execute(stmt)
        return result.rowcount == 1

    async def update_balance_optimistic(
        self,
        balance: UserBalance,
        delta: int,
        *,
        is_recharge: bool = False,
        is_consumption: bool = False,
    ) -> bool:
        """Change the balance by ``delta`` under optimistic locking.

        The new values are computed from the in-memory ``balance``. The
        version check guarantees the row still holds exactly those values
        when the UPDATE applies.

        Args:
            balance: Balance account.
            delta: Amount to add (positive) or subtract (negative).
            is_recharge: Also add ``delta`` to ``total_recharged`` (only when
                ``delta > 0``).
            is_consumption: Also add ``abs(delta)`` to ``total_consumed``
                (only when ``delta < 0``).

        Returns:
            ``True`` if the row was updated and ``balance`` now reflects it.
            ``False`` if the version no longer matched; nothing was written
            and ``balance`` is left unchanged.
        """
        current_version = balance.version
        changes: dict[str, Any] = {"balance": balance.balance + delta}
        if is_recharge and delta > 0:
            changes["total_recharged"] = balance.total_recharged + delta
        if is_consumption and delta < 0:
            changes["total_consumed"] = balance.total_consumed + abs(delta)

        if not await self._versioned_update(balance, current_version, changes):
            return False
        self._sync_committed(balance, {**changes, "version": current_version + 1})
        return True

    async def freeze_balance(
        self,
        balance: UserBalance,
        amount: int,
    ) -> bool:
        """Move ``amount`` from the available balance into the frozen balance.

        Args:
            balance: Balance account.
            amount: Amount to freeze (must be positive).

        Returns:
            ``True`` on success. ``False`` if ``amount`` is not positive,
            the available balance is insufficient, or the row changed
            concurrently.
        """
        if amount <= 0 or balance.available_balance < amount:
            return False

        current_version = balance.version
        updated = await self._versioned_update(
            balance,
            current_version,
            {"frozen_balance": UserBalance.frozen_balance + amount},
            # Re-check availability in SQL as well, against the stored row.
            UserBalance.balance - UserBalance.frozen_balance >= amount,
        )
        if not updated:
            return False
        self._sync_committed(
            balance,
            {
                "frozen_balance": balance.frozen_balance + amount,
                "version": current_version + 1,
            },
        )
        return True

    async def unfreeze_balance(
        self,
        balance: UserBalance,
        amount: int,
    ) -> bool:
        """Release ``amount`` from the frozen balance.

        Args:
            balance: Balance account.
            amount: Amount to unfreeze (must be positive).

        Returns:
            ``True`` on success. ``False`` if ``amount`` is not positive,
            exceeds the frozen balance, or the row changed concurrently.
        """
        if amount <= 0 or balance.frozen_balance < amount:
            return False

        current_version = balance.version
        updated = await self._versioned_update(
            balance,
            current_version,
            {"frozen_balance": UserBalance.frozen_balance - amount},
            UserBalance.frozen_balance >= amount,
        )
        if not updated:
            return False
        self._sync_committed(
            balance,
            {
                "frozen_balance": balance.frozen_balance - amount,
                "version": current_version + 1,
            },
        )
        return True
class TransactionRepository(BaseRepository[BalanceTransaction]):
    """Repository for balance transaction records (``BalanceTransaction``)."""

    model = BalanceTransaction

    @staticmethod
    def _filter_by_user(
        stmt: Any,
        user_id: str,
        transaction_type: TransactionType | None,
        status: TransactionStatus | None,
    ) -> Any:
        """Add the user / type / status WHERE clauses shared by list and count."""
        stmt = stmt.where(BalanceTransaction.user_id == user_id)
        # Compare against None rather than relying on truthiness. If these are
        # int- or str-based enums, a member with a falsy value (0 or "") would
        # otherwise silently disable the filter.
        if transaction_type is not None:
            stmt = stmt.where(BalanceTransaction.transaction_type == transaction_type)
        if status is not None:
            stmt = stmt.where(BalanceTransaction.status == status)
        return stmt

    async def get_by_user_id(
        self,
        user_id: str,
        *,
        offset: int = 0,
        limit: int = 20,
        transaction_type: TransactionType | None = None,
        status: TransactionStatus | None = None,
    ) -> list[BalanceTransaction]:
        """List a user's transactions, newest first.

        Args:
            user_id: User ID.
            offset: Number of records to skip.
            limit: Maximum number of records to return.
            transaction_type: Only return transactions of this type.
            status: Only return transactions with this status.

        Returns:
            The matching transactions, ordered by ``created_at`` descending.
        """
        stmt = self._filter_by_user(
            select(BalanceTransaction), user_id, transaction_type, status
        )
        stmt = (
            stmt
            .order_by(BalanceTransaction.created_at.desc())
            .offset(offset)
            .limit(limit)
        )
        result = await self.session.execute(stmt)
        return list(result.scalars().all())

    async def count_by_user_id(
        self,
        user_id: str,
        *,
        transaction_type: TransactionType | None = None,
        status: TransactionStatus | None = None,
    ) -> int:
        """Count a user's transactions.

        Uses the same filters as ``get_by_user_id``.

        Args:
            user_id: User ID.
            transaction_type: Only count transactions of this type.
            status: Only count transactions with this status.

        Returns:
            Number of matching records.
        """
        stmt = self._filter_by_user(
            select(func.count()).select_from(BalanceTransaction),
            user_id,
            transaction_type,
            status,
        )
        result = await self.session.execute(stmt)
        return result.scalar() or 0

    async def get_by_idempotency_key(
        self,
        idempotency_key: str,
    ) -> BalanceTransaction | None:
        """Look up a transaction by its idempotency key.

        Used to detect duplicate submissions.

        Args:
            idempotency_key: Idempotency key.

        Returns:
            The transaction, or ``None`` if no record has this key.
        """
        stmt = select(BalanceTransaction).where(
            BalanceTransaction.idempotency_key == idempotency_key
        )
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_by_reference(
        self,
        reference_type: str,
        reference_id: str,
    ) -> list[BalanceTransaction]:
        """List the transactions linked to a business entity, newest first.

        Args:
            reference_type: Type of the related business entity.
            reference_id: ID of the related business entity.

        Returns:
            The matching transactions, ordered by ``created_at`` descending.
        """
        stmt = (
            select(BalanceTransaction)
            .where(
                and_(
                    BalanceTransaction.reference_type == reference_type,
                    BalanceTransaction.reference_id == reference_id,
                )
            )
            .order_by(BalanceTransaction.created_at.desc())
        )
        result = await self.session.execute(stmt)
        return list(result.scalars().all())