"""
Balance repository.

Handles database operations for user balances and balance transactions.

Design notes:
- Optimistic locking (a ``version`` column) guards against concurrent update conflicts.
- Row-level (pessimistic) locking via ``SELECT ... FOR UPDATE`` is available for
  critical operations.
"""
from datetime import datetime
from typing import Any

from sqlalchemy import select, update, and_, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import set_committed_value

from app.models.balance import (
    UserBalance,
    BalanceTransaction,
    TransactionType,
    TransactionStatus,
)
from app.repositories.base import BaseRepository


class BalanceRepository(BaseRepository[UserBalance]):
    """Repository for ``UserBalance`` accounts."""

    model = UserBalance

    async def get_by_user_id(self, user_id: str) -> UserBalance | None:
        """
        Fetch a user's balance account by user ID.

        Args:
            user_id: User ID

        Returns:
            The balance account, or None if the user has none.
        """
        stmt = select(UserBalance).where(UserBalance.user_id == user_id)
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_by_user_id_for_update(self, user_id: str) -> UserBalance | None:
        """
        Fetch a user's balance account by user ID, taking a row-level lock.

        Intended for operations that must update the balance atomically,
        such as deductions.

        Args:
            user_id: User ID

        Returns:
            The balance account, or None if the user has none.
        """
        stmt = (
            select(UserBalance)
            .where(UserBalance.user_id == user_id)
            .with_for_update()  # row-level lock (SELECT ... FOR UPDATE)
        )
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_or_create(self, user_id: str) -> UserBalance:
        """
        Fetch the user's balance account, creating it if it does not exist.

        NOTE(review): two concurrent callers can both observe "missing" and
        both attempt to create; whether the loser fails depends on a unique
        constraint on ``user_id`` — confirm against the model.

        Args:
            user_id: User ID

        Returns:
            The balance account.
        """
        balance = await self.get_by_user_id(user_id)
        if balance is None:
            balance = await self.create(user_id=user_id)
        return balance

    async def get_or_create_for_update(self, user_id: str) -> UserBalance:
        """
        Fetch (or create) the user's balance account and hold a row-level lock on it.

        If the account does not exist it is created, flushed so the row exists
        in the database, and then re-selected with ``FOR UPDATE`` so the caller
        always receives a locked row.

        Args:
            user_id: User ID

        Returns:
            The locked balance account.
        """
        balance = await self.get_by_user_id_for_update(user_id)
        if balance is None:
            balance = await self.create(user_id=user_id)
            # Flush the INSERT, then re-fetch with a lock.
            await self.session.flush()
            balance = await self.get_by_user_id_for_update(user_id)
        return balance  # type: ignore

    @staticmethod
    def _apply_committed(balance: UserBalance, values: dict[str, Any]) -> None:
        """
        Mirror values already written by an UPDATE onto the in-memory object.

        ``set_committed_value`` assigns without recording attribute history,
        so the session does not treat these values as pending changes and
        re-issue them at the next flush.
        """
        for key, value in values.items():
            set_committed_value(balance, key, value)

    async def update_balance_optimistic(
        self,
        balance: UserBalance,
        delta: int,
        *,
        is_recharge: bool = False,
        is_consumption: bool = False,
    ) -> bool:
        """
        Change the balance using optimistic locking.

        The UPDATE only matches if the row's ``version`` still equals the
        version of the in-memory object; otherwise nothing is written.

        Args:
            balance: Balance account
            delta: Amount to change by (positive increases, negative decreases)
            is_recharge: Whether this is a recharge (a positive delta is added
                to ``total_recharged``)
            is_consumption: Whether this is a consumption (the magnitude of a
                negative delta is added to ``total_consumed``)

        Returns:
            True if exactly one row was updated, False on a version conflict.
        """
        current_version = balance.version
        new_balance = balance.balance + delta

        # Absolute new values computed from the in-memory object; if the row
        # changed since it was read, the version check makes the UPDATE match
        # zero rows instead of overwriting it.
        update_values: dict[str, Any] = {
            "balance": new_balance,
            "version": current_version + 1,
        }
        if is_recharge and delta > 0:
            update_values["total_recharged"] = balance.total_recharged + delta
        if is_consumption and delta < 0:
            update_values["total_consumed"] = balance.total_consumed + abs(delta)

        stmt = (
            update(UserBalance)
            .where(
                and_(
                    UserBalance.id == balance.id,
                    UserBalance.version == current_version,  # optimistic-lock check
                )
            )
            .values(**update_values)
            # Disable ORM auto-synchronization: otherwise the session already
            # applies these values to `balance` (even when 0 rows matched), and
            # the manual sync below would apply them a second time.
            .execution_options(synchronize_session=False)
        )
        result = await self.session.execute(stmt)

        if result.rowcount != 1:
            return False

        self._apply_committed(balance, update_values)
        return True

    async def freeze_balance(
        self,
        balance: UserBalance,
        amount: int,
    ) -> bool:
        """
        Freeze part of the available balance.

        Args:
            balance: Balance account
            amount: Amount to freeze (must be positive)

        Returns:
            True on success; False if the amount is not positive, available
            funds are insufficient, or the version check fails.
        """
        if amount <= 0:
            return False

        if balance.available_balance < amount:
            return False

        current_version = balance.version

        stmt = (
            update(UserBalance)
            .where(
                and_(
                    UserBalance.id == balance.id,
                    UserBalance.version == current_version,
                    # Re-check available funds in the database, not only in memory.
                    UserBalance.balance - UserBalance.frozen_balance >= amount,
                )
            )
            .values(
                frozen_balance=UserBalance.frozen_balance + amount,
                version=current_version + 1,
            )
            # In-memory state is synchronized manually below, only on success.
            .execution_options(synchronize_session=False)
        )
        result = await self.session.execute(stmt)

        if result.rowcount != 1:
            return False

        self._apply_committed(
            balance,
            {
                "frozen_balance": balance.frozen_balance + amount,
                "version": current_version + 1,
            },
        )
        return True

    async def unfreeze_balance(
        self,
        balance: UserBalance,
        amount: int,
    ) -> bool:
        """
        Release part of the frozen balance.

        Args:
            balance: Balance account
            amount: Amount to unfreeze (must be positive)

        Returns:
            True on success; False if the amount is not positive, exceeds the
            frozen balance, or the version check fails.
        """
        if amount <= 0:
            return False

        if balance.frozen_balance < amount:
            return False

        current_version = balance.version

        stmt = (
            update(UserBalance)
            .where(
                and_(
                    UserBalance.id == balance.id,
                    UserBalance.version == current_version,
                    # Re-check in the database that enough is frozen.
                    UserBalance.frozen_balance >= amount,
                )
            )
            .values(
                frozen_balance=UserBalance.frozen_balance - amount,
                version=current_version + 1,
            )
            # In-memory state is synchronized manually below, only on success.
            .execution_options(synchronize_session=False)
        )
        result = await self.session.execute(stmt)

        if result.rowcount != 1:
            return False

        self._apply_committed(
            balance,
            {
                "frozen_balance": balance.frozen_balance - amount,
                "version": current_version + 1,
            },
        )
        return True


class TransactionRepository(BaseRepository[BalanceTransaction]):
    """Repository for ``BalanceTransaction`` records."""

    model = BalanceTransaction

    @staticmethod
    def _user_criteria(
        user_id: str,
        transaction_type: TransactionType | None,
        status: TransactionStatus | None,
    ) -> list[Any]:
        """Build the WHERE criteria shared by the per-user listing and count queries."""
        criteria: list[Any] = [BalanceTransaction.user_id == user_id]
        # Compare against None explicitly so enum members are never skipped
        # because of their truthiness.
        if transaction_type is not None:
            criteria.append(BalanceTransaction.transaction_type == transaction_type)
        if status is not None:
            criteria.append(BalanceTransaction.status == status)
        return criteria

    async def get_by_user_id(
        self,
        user_id: str,
        *,
        offset: int = 0,
        limit: int = 20,
        transaction_type: TransactionType | None = None,
        status: TransactionStatus | None = None,
    ) -> list[BalanceTransaction]:
        """
        List a user's transactions, newest first.

        Args:
            user_id: User ID
            offset: Number of rows to skip
            limit: Maximum number of rows to return
            transaction_type: Optional transaction-type filter
            status: Optional status filter

        Returns:
            List of transactions ordered by ``created_at`` descending.
        """
        stmt = (
            select(BalanceTransaction)
            .where(*self._user_criteria(user_id, transaction_type, status))
            .order_by(BalanceTransaction.created_at.desc())
            .offset(offset)
            .limit(limit)
        )
        result = await self.session.execute(stmt)
        return list(result.scalars().all())

    async def count_by_user_id(
        self,
        user_id: str,
        *,
        transaction_type: TransactionType | None = None,
        status: TransactionStatus | None = None,
    ) -> int:
        """
        Count a user's transactions.

        Args:
            user_id: User ID
            transaction_type: Optional transaction-type filter
            status: Optional status filter

        Returns:
            Number of matching transactions.
        """
        stmt = (
            select(func.count())
            .select_from(BalanceTransaction)
            .where(*self._user_criteria(user_id, transaction_type, status))
        )
        result = await self.session.execute(stmt)
        return result.scalar() or 0

    async def get_by_idempotency_key(
        self,
        idempotency_key: str,
    ) -> BalanceTransaction | None:
        """
        Fetch a transaction by its idempotency key.

        Used to detect duplicate submissions.

        Args:
            idempotency_key: Idempotency key

        Returns:
            The transaction, or None if no transaction has this key.
        """
        stmt = select(BalanceTransaction).where(
            BalanceTransaction.idempotency_key == idempotency_key
        )
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_by_reference(
        self,
        reference_type: str,
        reference_id: str,
    ) -> list[BalanceTransaction]:
        """
        List transactions linked to a business reference, newest first.

        Args:
            reference_type: Type of the linked business entity
            reference_id: ID of the linked business entity

        Returns:
            List of transactions ordered by ``created_at`` descending.
        """
        stmt = (
            select(BalanceTransaction)
            .where(
                and_(
                    BalanceTransaction.reference_type == reference_type,
                    BalanceTransaction.reference_id == reference_id,
                )
            )
            .order_by(BalanceTransaction.created_at.desc())
        )
        result = await self.session.execute(stmt)
        return list(result.scalars().all())