463 lines
13 KiB
Python
463 lines
13 KiB
Python
"""
兑换码仓库

处理兑换码相关的数据库操作。
"""
|
|
|
|
from datetime import datetime, timezone
|
|
from typing import Any
|
|
|
|
from sqlalchemy import select, update, and_, or_, func
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.models.redeem_code import (
|
|
RedeemCode,
|
|
RedeemCodeBatch,
|
|
RedeemCodeUsageLog,
|
|
RedeemCodeStatus,
|
|
)
|
|
from app.repositories.base import BaseRepository
|
|
|
|
|
|
class RedeemCodeRepository(BaseRepository[RedeemCode]):
    """Repository for ``RedeemCode`` rows (lookup, redemption, admin listing).

    All queries are executed through ``self.session``.
    """

    model = RedeemCode

    @staticmethod
    def _normalize_code(code: str) -> str:
        """Return *code* in canonical form: stripped, upper-cased, spaces removed.

        Lookups compare ``RedeemCode.code`` against this form, so stored codes
        are expected to be normalized the same way (assumption — confirm
        against the code-creation path).
        """
        return code.strip().upper().replace(" ", "")

    async def get_by_code(self, code: str) -> RedeemCode | None:
        """
        Look up a redeem code by its (user-supplied) text.

        Args:
            code: Raw code as entered; normalized before comparison.

        Returns:
            The matching ``RedeemCode``, or ``None`` if no row matches.
        """
        stmt = select(RedeemCode).where(
            RedeemCode.code == self._normalize_code(code)
        )
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_by_code_for_update(self, code: str) -> RedeemCode | None:
        """
        Look up a redeem code and lock its row (``SELECT ... FOR UPDATE``).

        Intended for the redemption path: concurrent transactions that lock
        the same row wait until this transaction ends, which prevents two
        redemptions from racing on the same code.

        Args:
            code: Raw code as entered; normalized before comparison.

        Returns:
            The matching ``RedeemCode``, or ``None`` if no row matches.
        """
        stmt = (
            select(RedeemCode)
            .where(RedeemCode.code == self._normalize_code(code))
            .with_for_update()
        )
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_valid_code(self, code: str) -> RedeemCode | None:
        """
        Return the code only if it can currently be redeemed.

        A code is valid when its status is ACTIVE, ``used_count < max_uses``,
        and it either has no ``expires_at`` or expires after the current UTC
        time.

        Args:
            code: Raw code as entered; normalized before comparison.

        Returns:
            The valid ``RedeemCode``, or ``None`` if it is missing or unusable.
        """
        # NOTE(review): compares against a timezone-aware UTC "now"; assumes
        # expires_at is stored as an aware timestamp — confirm.
        now = datetime.now(timezone.utc)

        stmt = select(RedeemCode).where(
            and_(
                RedeemCode.code == self._normalize_code(code),
                RedeemCode.status == RedeemCodeStatus.ACTIVE,
                RedeemCode.used_count < RedeemCode.max_uses,
                or_(
                    RedeemCode.expires_at.is_(None),
                    RedeemCode.expires_at > now,
                ),
            )
        )
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def mark_as_used(
        self,
        code: RedeemCode,
        user_id: str,
    ) -> bool:
        """
        Consume one use of *code* on behalf of *user_id*.

        Increments ``used_count`` and switches the status to USED once the
        last use is consumed. The UPDATE only applies while the row still has
        the ``used_count`` observed on *code* (optimistic lock) and is still
        ACTIVE. A concurrent redemption or a concurrent disable therefore
        makes this call return ``False`` rather than double-counting or
        re-activating the code. On success, the in-memory *code* object is
        updated to match the row.

        ``used_by`` / ``used_at`` only record the most recent redemption.

        Args:
            code: The redeem code being consumed.
            user_id: ID of the redeeming user.

        Returns:
            ``True`` if exactly one row was updated. ``False`` if the code
            had no uses left, was no longer ACTIVE, or was changed
            concurrently.
        """
        # Nothing left to consume: don't push used_count past max_uses.
        if code.used_count >= code.max_uses:
            return False

        now = datetime.now(timezone.utc)
        new_used_count = code.used_count + 1
        new_status = (
            RedeemCodeStatus.USED
            if new_used_count >= code.max_uses
            else RedeemCodeStatus.ACTIVE
        )

        stmt = (
            update(RedeemCode)
            .where(
                and_(
                    RedeemCode.id == code.id,
                    # Optimistic lock: fails if another transaction consumed
                    # a use since `code` was loaded.
                    RedeemCode.used_count == code.used_count,
                    # disable_code changes only `status`, never `used_count`,
                    # so the lock above cannot see a concurrent disable. This
                    # UPDATE would then overwrite DISABLED with ACTIVE/USED.
                    RedeemCode.status == RedeemCodeStatus.ACTIVE,
                )
            )
            .values(
                used_count=new_used_count,
                status=new_status,
                used_by=user_id,
                used_at=now,
            )
        )

        result = await self.session.execute(stmt)
        if result.rowcount != 1:
            return False

        code.used_count = new_used_count
        code.status = new_status
        code.used_by = user_id
        code.used_at = now
        return True

    @classmethod
    def _build_filter_conditions(
        cls,
        *,
        status: RedeemCodeStatus | None,
        batch_id: str | None,
        code_like: str | None,
        created_after: datetime | None,
        created_before: datetime | None,
    ) -> list[Any]:
        """Translate the optional listing filters into WHERE conditions.

        Shared by ``get_all_with_filters`` and ``count_with_filters``, so a
        page and its total count always agree on what matches.
        """
        conditions: list[Any] = []
        # `is not None` instead of truthiness: a falsy enum member (e.g. an
        # IntEnum with value 0) must still filter.
        if status is not None:
            conditions.append(RedeemCode.status == status)
        # Empty strings mean "no filter" for the string arguments.
        if batch_id:
            conditions.append(RedeemCode.batch_id == batch_id)
        if code_like:
            # Same normalization as exact lookups. Whitespace-only input
            # normalizes to "" and is treated as "no filter".
            pattern = cls._normalize_code(code_like)
            if pattern:
                # autoescape=True makes user-typed '%' and '_' match
                # literally instead of acting as LIKE wildcards.
                conditions.append(
                    RedeemCode.code.contains(pattern, autoescape=True)
                )
        if created_after is not None:
            conditions.append(RedeemCode.created_at >= created_after)
        if created_before is not None:
            conditions.append(RedeemCode.created_at <= created_before)
        return conditions

    async def get_all_with_filters(
        self,
        *,
        offset: int = 0,
        limit: int = 20,
        status: RedeemCodeStatus | None = None,
        batch_id: str | None = None,
        code_like: str | None = None,
        created_after: datetime | None = None,
        created_before: datetime | None = None,
    ) -> list[RedeemCode]:
        """
        List redeem codes, newest first, with optional filters.

        Args:
            offset: Number of rows to skip.
            limit: Maximum number of rows to return.
            status: Only codes with this status.
            batch_id: Only codes from this batch.
            code_like: Substring match on the (normalized) code text.
            created_after: Only codes created at or after this time.
            created_before: Only codes created at or before this time.

        Returns:
            The matching codes for the requested page.
        """
        stmt = select(RedeemCode)

        conditions = self._build_filter_conditions(
            status=status,
            batch_id=batch_id,
            code_like=code_like,
            created_after=created_after,
            created_before=created_before,
        )
        if conditions:
            stmt = stmt.where(and_(*conditions))

        stmt = (
            stmt
            .order_by(RedeemCode.created_at.desc())
            .offset(offset)
            .limit(limit)
        )

        result = await self.session.execute(stmt)
        return list(result.scalars().all())

    async def count_with_filters(
        self,
        *,
        status: RedeemCodeStatus | None = None,
        batch_id: str | None = None,
        code_like: str | None = None,
        created_after: datetime | None = None,
        created_before: datetime | None = None,
    ) -> int:
        """
        Count redeem codes matching the same filters as ``get_all_with_filters``.

        Returns:
            The number of matching codes (0 if none).
        """
        stmt = select(func.count()).select_from(RedeemCode)

        conditions = self._build_filter_conditions(
            status=status,
            batch_id=batch_id,
            code_like=code_like,
            created_after=created_after,
            created_before=created_before,
        )
        if conditions:
            stmt = stmt.where(and_(*conditions))

        result = await self.session.execute(stmt)
        return result.scalar() or 0

    async def bulk_create(
        self,
        codes_data: list[dict[str, Any]],
    ) -> list[RedeemCode]:
        """
        Create many redeem codes in one flush.

        Each dict is passed as keyword arguments to ``RedeemCode``. The new
        objects are added to the session and flushed, so database-generated
        values (e.g. primary keys) are populated. This method does not commit.

        Args:
            codes_data: One dict of column values per code.

        Returns:
            The created ``RedeemCode`` objects, in input order.
        """
        codes = [RedeemCode(**data) for data in codes_data]
        self.session.add_all(codes)
        await self.session.flush()
        return codes

    async def disable_code(self, code: RedeemCode) -> RedeemCode:
        """
        Set the code's status to DISABLED.

        Args:
            code: The redeem code to disable.

        Returns:
            The updated code, as returned by the inherited ``update`` helper.
        """
        return await self.update(code, status=RedeemCodeStatus.DISABLED)

    async def enable_code(self, code: RedeemCode) -> RedeemCode:
        """
        Re-enable a DISABLED code.

        Codes in any other status are returned unchanged. A disabled code
        whose uses are exhausted becomes USED instead of ACTIVE.

        Args:
            code: The redeem code to enable.

        Returns:
            The (possibly) updated code.
        """
        # Only DISABLED codes can be re-enabled.
        if code.status != RedeemCodeStatus.DISABLED:
            return code

        # All uses consumed: restore to USED rather than ACTIVE.
        if code.used_count >= code.max_uses:
            return await self.update(code, status=RedeemCodeStatus.USED)

        return await self.update(code, status=RedeemCodeStatus.ACTIVE)
|
|
|
|
|
|
class RedeemCodeBatchRepository(BaseRepository[RedeemCodeBatch]):
    """Repository for ``RedeemCodeBatch`` rows (batches of redeem codes)."""

    model = RedeemCodeBatch

    async def get_all_batches(
        self,
        *,
        offset: int = 0,
        limit: int = 20,
    ) -> list[RedeemCodeBatch]:
        """
        Return one page of batches, most recently created first.

        Args:
            offset: Number of batches to skip.
            limit: Maximum number of batches to return.

        Returns:
            The batches on the requested page.
        """
        newest_first = RedeemCodeBatch.created_at.desc()
        query = select(RedeemCodeBatch).order_by(newest_first)
        page = query.offset(offset).limit(limit)

        rows = await self.session.execute(page)
        return [*rows.scalars().all()]

    async def increment_used_count(self, batch_id: str) -> None:
        """
        Add one to the batch's ``used_count``.

        The increment is computed in SQL (``used_count = used_count + 1``)
        rather than read into Python first, so concurrent callers don't
        overwrite each other's increments. If no batch has *batch_id*, no
        row changes and nothing is reported.

        Args:
            batch_id: ID of the batch to update.
        """
        bumped = RedeemCodeBatch.used_count + 1
        await self.session.execute(
            update(RedeemCodeBatch)
            .values(used_count=bumped)
            .where(RedeemCodeBatch.id == batch_id)
        )
|
|
|
|
|
|
class RedeemCodeUsageLogRepository(BaseRepository[RedeemCodeUsageLog]):
    """Repository for ``RedeemCodeUsageLog`` rows (redemption history)."""

    model = RedeemCodeUsageLog

    async def get_by_code_id(
        self,
        redeem_code_id: str,
        *,
        offset: int = 0,
        limit: int = 20,
    ) -> list[RedeemCodeUsageLog]:
        """
        Return usage logs for one redeem code, newest first.

        Args:
            redeem_code_id: ID of the redeem code.
            offset: Number of rows to skip.
            limit: Maximum number of rows to return.

        Returns:
            The usage logs on the requested page.
        """
        stmt = (
            select(RedeemCodeUsageLog)
            .where(RedeemCodeUsageLog.redeem_code_id == redeem_code_id)
            .order_by(RedeemCodeUsageLog.created_at.desc())
            .offset(offset)
            .limit(limit)
        )
        result = await self.session.execute(stmt)
        return list(result.scalars().all())

    async def get_by_user_id(
        self,
        user_id: str,
        *,
        offset: int = 0,
        limit: int = 20,
    ) -> list[RedeemCodeUsageLog]:
        """
        Return one user's redemption logs, newest first.

        Args:
            user_id: ID of the user.
            offset: Number of rows to skip.
            limit: Maximum number of rows to return.

        Returns:
            The usage logs on the requested page.
        """
        stmt = (
            select(RedeemCodeUsageLog)
            .where(RedeemCodeUsageLog.user_id == user_id)
            .order_by(RedeemCodeUsageLog.created_at.desc())
            .offset(offset)
            .limit(limit)
        )
        result = await self.session.execute(stmt)
        return list(result.scalars().all())

    @staticmethod
    def _build_filter_conditions(
        *,
        redeem_code_id: str | None,
        user_id: str | None,
        code_like: str | None,
        created_after: datetime | None,
        created_before: datetime | None,
    ) -> list[Any]:
        """Translate the optional log filters into WHERE conditions.

        Shared by ``get_all_with_filters`` and ``count_with_filters``, so a
        page and its total count always agree on what matches.
        """
        conditions: list[Any] = []
        # Empty strings mean "no filter" for the string arguments.
        if redeem_code_id:
            conditions.append(RedeemCodeUsageLog.redeem_code_id == redeem_code_id)
        if user_id:
            conditions.append(RedeemCodeUsageLog.user_id == user_id)
        if code_like:
            # Normalize like the code lookups do (strip, upper-case, drop
            # spaces). Whitespace-only input becomes "" and means no filter.
            pattern = code_like.strip().upper().replace(" ", "")
            if pattern:
                # autoescape=True makes user-typed '%' and '_' match
                # literally instead of acting as LIKE wildcards.
                conditions.append(
                    RedeemCodeUsageLog.code_snapshot.contains(
                        pattern, autoescape=True
                    )
                )
        if created_after is not None:
            conditions.append(RedeemCodeUsageLog.created_at >= created_after)
        if created_before is not None:
            conditions.append(RedeemCodeUsageLog.created_at <= created_before)
        return conditions

    async def get_all_with_filters(
        self,
        *,
        offset: int = 0,
        limit: int = 20,
        redeem_code_id: str | None = None,
        user_id: str | None = None,
        code_like: str | None = None,
        created_after: datetime | None = None,
        created_before: datetime | None = None,
    ) -> list[RedeemCodeUsageLog]:
        """
        List usage logs, newest first, with optional filters.

        Args:
            offset: Number of rows to skip.
            limit: Maximum number of rows to return.
            redeem_code_id: Only logs for this redeem code.
            user_id: Only logs for this user.
            code_like: Substring match on ``code_snapshot``.
            created_after: Only logs created at or after this time.
            created_before: Only logs created at or before this time.

        Returns:
            The matching usage logs for the requested page.
        """
        stmt = select(RedeemCodeUsageLog)

        conditions = self._build_filter_conditions(
            redeem_code_id=redeem_code_id,
            user_id=user_id,
            code_like=code_like,
            created_after=created_after,
            created_before=created_before,
        )
        if conditions:
            stmt = stmt.where(and_(*conditions))

        stmt = (
            stmt
            .order_by(RedeemCodeUsageLog.created_at.desc())
            .offset(offset)
            .limit(limit)
        )

        result = await self.session.execute(stmt)
        return list(result.scalars().all())

    async def count_with_filters(
        self,
        *,
        redeem_code_id: str | None = None,
        user_id: str | None = None,
        code_like: str | None = None,
        created_after: datetime | None = None,
        created_before: datetime | None = None,
    ) -> int:
        """
        Count usage logs matching the same filters as ``get_all_with_filters``.

        Returns:
            The number of matching logs (0 if none).
        """
        stmt = select(func.count()).select_from(RedeemCodeUsageLog)

        conditions = self._build_filter_conditions(
            redeem_code_id=redeem_code_id,
            user_id=user_id,
            code_like=code_like,
            created_after=created_after,
            created_before=created_before,
        )
        if conditions:
            stmt = stmt.where(and_(*conditions))

        result = await self.session.execute(stmt)
        return result.scalar() or 0
|
|
|