""" 兑换码仓库 处理兑换码相关的数据库操作。 """ from datetime import datetime, timezone from typing import Any from sqlalchemy import select, update, and_, or_, func from sqlalchemy.ext.asyncio import AsyncSession from app.models.redeem_code import ( RedeemCode, RedeemCodeBatch, RedeemCodeUsageLog, RedeemCodeStatus, ) from app.repositories.base import BaseRepository class RedeemCodeRepository(BaseRepository[RedeemCode]): """兑换码仓库""" model = RedeemCode async def get_by_code(self, code: str) -> RedeemCode | None: """ 通过兑换码获取记录 Args: code: 兑换码 Returns: 兑换码记录或 None """ # 标准化兑换码格式 normalized_code = code.strip().upper().replace(" ", "") stmt = select(RedeemCode).where(RedeemCode.code == normalized_code) result = await self.session.execute(stmt) return result.scalar_one_or_none() async def get_by_code_for_update(self, code: str) -> RedeemCode | None: """ 通过兑换码获取记录(加行级锁) 用于兑换操作,防止并发兑换。 Args: code: 兑换码 Returns: 兑换码记录或 None """ normalized_code = code.strip().upper().replace(" ", "") stmt = ( select(RedeemCode) .where(RedeemCode.code == normalized_code) .with_for_update() ) result = await self.session.execute(stmt) return result.scalar_one_or_none() async def get_valid_code(self, code: str) -> RedeemCode | None: """ 获取有效的兑换码 检查状态、使用次数和有效期。 Args: code: 兑换码 Returns: 有效的兑换码或 None """ normalized_code = code.strip().upper().replace(" ", "") now = datetime.now(timezone.utc) stmt = ( select(RedeemCode) .where( and_( RedeemCode.code == normalized_code, RedeemCode.status == RedeemCodeStatus.ACTIVE, RedeemCode.used_count < RedeemCode.max_uses, or_( RedeemCode.expires_at.is_(None), RedeemCode.expires_at > now, ), ) ) ) result = await self.session.execute(stmt) return result.scalar_one_or_none() async def mark_as_used( self, code: RedeemCode, user_id: str, ) -> bool: """ 标记兑换码已使用 更新使用计数和状态。 Args: code: 兑换码记录 user_id: 使用者 ID Returns: 是否成功 """ now = datetime.now(timezone.utc) new_used_count = code.used_count + 1 new_status = ( RedeemCodeStatus.USED if new_used_count >= code.max_uses else RedeemCodeStatus.ACTIVE ) stmt = ( update(RedeemCode) .where( and_( RedeemCode.id == code.id, RedeemCode.used_count == code.used_count, # 乐观锁 ) ) .values( used_count=new_used_count, status=new_status, used_by=user_id, used_at=now, ) ) result = await self.session.execute(stmt) if result.rowcount == 1: code.used_count = new_used_count code.status = new_status code.used_by = user_id code.used_at = now return True return False async def get_all_with_filters( self, *, offset: int = 0, limit: int = 20, status: RedeemCodeStatus | None = None, batch_id: str | None = None, code_like: str | None = None, created_after: datetime | None = None, created_before: datetime | None = None, ) -> list[RedeemCode]: """ 获取兑换码列表(支持过滤) Args: offset: 偏移量 limit: 限制数量 status: 状态过滤 batch_id: 批次 ID 过滤 code_like: 兑换码模糊匹配 created_after: 创建时间起始 created_before: 创建时间结束 Returns: 兑换码列表 """ stmt = select(RedeemCode) conditions = [] if status: conditions.append(RedeemCode.status == status) if batch_id: conditions.append(RedeemCode.batch_id == batch_id) if code_like: conditions.append(RedeemCode.code.contains(code_like.upper())) if created_after: conditions.append(RedeemCode.created_at >= created_after) if created_before: conditions.append(RedeemCode.created_at <= created_before) if conditions: stmt = stmt.where(and_(*conditions)) stmt = ( stmt .order_by(RedeemCode.created_at.desc()) .offset(offset) .limit(limit) ) result = await self.session.execute(stmt) return list(result.scalars().all()) async def count_with_filters( self, *, status: RedeemCodeStatus | None = None, batch_id: str | None = None, code_like: str | None = None, created_after: datetime | None = None, created_before: datetime | None = None, ) -> int: """ 统计兑换码数量(支持过滤) """ stmt = select(func.count()).select_from(RedeemCode) conditions = [] if status: conditions.append(RedeemCode.status == status) if batch_id: conditions.append(RedeemCode.batch_id == batch_id) if code_like: conditions.append(RedeemCode.code.contains(code_like.upper())) if created_after: conditions.append(RedeemCode.created_at >= created_after) if created_before: conditions.append(RedeemCode.created_at <= created_before) if conditions: stmt = stmt.where(and_(*conditions)) result = await self.session.execute(stmt) return result.scalar() or 0 async def bulk_create( self, codes_data: list[dict[str, Any]], ) -> list[RedeemCode]: """ 批量创建兑换码 Args: codes_data: 兑换码数据列表 Returns: 创建的兑换码列表 """ codes = [RedeemCode(**data) for data in codes_data] self.session.add_all(codes) await self.session.flush() return codes async def disable_code(self, code: RedeemCode) -> RedeemCode: """ 禁用兑换码 Args: code: 兑换码记录 Returns: 更新后的兑换码 """ return await self.update(code, status=RedeemCodeStatus.DISABLED) async def enable_code(self, code: RedeemCode) -> RedeemCode: """ 启用兑换码 Args: code: 兑换码记录 Returns: 更新后的兑换码 """ # 只有禁用状态的可以重新启用 if code.status != RedeemCodeStatus.DISABLED: return code # 如果使用次数已满,改为已使用状态 if code.used_count >= code.max_uses: return await self.update(code, status=RedeemCodeStatus.USED) return await self.update(code, status=RedeemCodeStatus.ACTIVE) class RedeemCodeBatchRepository(BaseRepository[RedeemCodeBatch]): """兑换码批次仓库""" model = RedeemCodeBatch async def get_all_batches( self, *, offset: int = 0, limit: int = 20, ) -> list[RedeemCodeBatch]: """ 获取所有批次 Args: offset: 偏移量 limit: 限制数量 Returns: 批次列表 """ stmt = ( select(RedeemCodeBatch) .order_by(RedeemCodeBatch.created_at.desc()) .offset(offset) .limit(limit) ) result = await self.session.execute(stmt) return list(result.scalars().all()) async def increment_used_count(self, batch_id: str) -> None: """ 增加批次已使用计数 Args: batch_id: 批次 ID """ stmt = ( update(RedeemCodeBatch) .where(RedeemCodeBatch.id == batch_id) .values(used_count=RedeemCodeBatch.used_count + 1) ) await self.session.execute(stmt) class RedeemCodeUsageLogRepository(BaseRepository[RedeemCodeUsageLog]): """兑换码使用日志仓库""" model = RedeemCodeUsageLog async def get_by_code_id( self, redeem_code_id: str, *, offset: int = 0, limit: int = 20, ) -> list[RedeemCodeUsageLog]: """ 获取兑换码的使用日志 Args: redeem_code_id: 兑换码 ID offset: 偏移量 limit: 限制数量 Returns: 使用日志列表 """ stmt = ( select(RedeemCodeUsageLog) .where(RedeemCodeUsageLog.redeem_code_id == redeem_code_id) .order_by(RedeemCodeUsageLog.created_at.desc()) .offset(offset) .limit(limit) ) result = await self.session.execute(stmt) return list(result.scalars().all()) async def get_by_user_id( self, user_id: str, *, offset: int = 0, limit: int = 20, ) -> list[RedeemCodeUsageLog]: """ 获取用户的兑换日志 Args: user_id: 用户 ID offset: 偏移量 limit: 限制数量 Returns: 使用日志列表 """ stmt = ( select(RedeemCodeUsageLog) .where(RedeemCodeUsageLog.user_id == user_id) .order_by(RedeemCodeUsageLog.created_at.desc()) .offset(offset) .limit(limit) ) result = await self.session.execute(stmt) return list(result.scalars().all()) async def get_all_with_filters( self, *, offset: int = 0, limit: int = 20, redeem_code_id: str | None = None, user_id: str | None = None, code_like: str | None = None, created_after: datetime | None = None, created_before: datetime | None = None, ) -> list[RedeemCodeUsageLog]: """ 获取使用日志(支持过滤) """ stmt = select(RedeemCodeUsageLog) conditions = [] if redeem_code_id: conditions.append(RedeemCodeUsageLog.redeem_code_id == redeem_code_id) if user_id: conditions.append(RedeemCodeUsageLog.user_id == user_id) if code_like: conditions.append(RedeemCodeUsageLog.code_snapshot.contains(code_like.upper())) if created_after: conditions.append(RedeemCodeUsageLog.created_at >= created_after) if created_before: conditions.append(RedeemCodeUsageLog.created_at <= created_before) if conditions: stmt = stmt.where(and_(*conditions)) stmt = ( stmt .order_by(RedeemCodeUsageLog.created_at.desc()) .offset(offset) .limit(limit) ) result = await self.session.execute(stmt) return list(result.scalars().all()) async def count_with_filters( self, *, redeem_code_id: str | None = None, user_id: str | None = None, code_like: str | None = None, created_after: datetime | None = None, created_before: datetime | None = None, ) -> int: """ 统计使用日志数量 """ stmt = select(func.count()).select_from(RedeemCodeUsageLog) conditions = [] if redeem_code_id: conditions.append(RedeemCodeUsageLog.redeem_code_id == redeem_code_id) if user_id: conditions.append(RedeemCodeUsageLog.user_id == user_id) if code_like: conditions.append(RedeemCodeUsageLog.code_snapshot.contains(code_like.upper())) if created_after: conditions.append(RedeemCodeUsageLog.created_at >= created_after) if created_before: conditions.append(RedeemCodeUsageLog.created_at <= created_before) if conditions: stmt = stmt.where(and_(*conditions)) result = await self.session.execute(stmt) return result.scalar() or 0