提供基本前后端骨架
This commit is contained in:
19
app/repositories/__init__.py
Normal file
19
app/repositories/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""数据仓库层"""
|
||||
|
||||
from app.repositories.user import UserRepository
|
||||
from app.repositories.balance import BalanceRepository, TransactionRepository
|
||||
from app.repositories.redeem_code import (
|
||||
RedeemCodeRepository,
|
||||
RedeemCodeBatchRepository,
|
||||
RedeemCodeUsageLogRepository,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"UserRepository",
|
||||
"BalanceRepository",
|
||||
"TransactionRepository",
|
||||
"RedeemCodeRepository",
|
||||
"RedeemCodeBatchRepository",
|
||||
"RedeemCodeUsageLogRepository",
|
||||
]
|
||||
|
||||
378
app/repositories/balance.py
Normal file
378
app/repositories/balance.py
Normal file
@@ -0,0 +1,378 @@
|
||||
"""
|
||||
余额仓库
|
||||
|
||||
处理余额相关的数据库操作。
|
||||
|
||||
设计说明:
|
||||
- 使用乐观锁(version)防止并发更新冲突
|
||||
- 提供行级锁支持(悲观锁)用于关键操作
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select, update, and_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.balance import (
|
||||
UserBalance,
|
||||
BalanceTransaction,
|
||||
TransactionType,
|
||||
TransactionStatus,
|
||||
)
|
||||
from app.repositories.base import BaseRepository
|
||||
|
||||
|
||||
class BalanceRepository(BaseRepository[UserBalance]):
|
||||
"""余额仓库"""
|
||||
|
||||
model = UserBalance
|
||||
|
||||
async def get_by_user_id(self, user_id: str) -> UserBalance | None:
|
||||
"""
|
||||
通过用户 ID 获取余额账户
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
|
||||
Returns:
|
||||
余额账户或 None
|
||||
"""
|
||||
stmt = select(UserBalance).where(UserBalance.user_id == user_id)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_user_id_for_update(self, user_id: str) -> UserBalance | None:
|
||||
"""
|
||||
通过用户 ID 获取余额账户(加行级锁)
|
||||
|
||||
用于需要原子性更新的场景,如扣款操作。
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
|
||||
Returns:
|
||||
余额账户或 None
|
||||
"""
|
||||
stmt = (
|
||||
select(UserBalance)
|
||||
.where(UserBalance.user_id == user_id)
|
||||
.with_for_update() # 行级锁
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_or_create(self, user_id: str) -> UserBalance:
|
||||
"""
|
||||
获取或创建余额账户
|
||||
|
||||
如果用户没有余额账户,自动创建一个。
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
|
||||
Returns:
|
||||
余额账户
|
||||
"""
|
||||
balance = await self.get_by_user_id(user_id)
|
||||
if balance is None:
|
||||
balance = await self.create(user_id=user_id)
|
||||
return balance
|
||||
|
||||
async def get_or_create_for_update(self, user_id: str) -> UserBalance:
|
||||
"""
|
||||
获取或创建余额账户(加行级锁)
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
|
||||
Returns:
|
||||
余额账户
|
||||
"""
|
||||
balance = await self.get_by_user_id_for_update(user_id)
|
||||
if balance is None:
|
||||
balance = await self.create(user_id=user_id)
|
||||
# 重新获取并加锁
|
||||
await self.session.flush()
|
||||
balance = await self.get_by_user_id_for_update(user_id)
|
||||
return balance # type: ignore
|
||||
|
||||
async def update_balance_optimistic(
|
||||
self,
|
||||
balance: UserBalance,
|
||||
delta: int,
|
||||
*,
|
||||
is_recharge: bool = False,
|
||||
is_consumption: bool = False,
|
||||
) -> bool:
|
||||
"""
|
||||
使用乐观锁更新余额
|
||||
|
||||
通过版本号检查确保并发安全。
|
||||
|
||||
Args:
|
||||
balance: 余额账户
|
||||
delta: 变化量(正数增加,负数减少)
|
||||
is_recharge: 是否为充值
|
||||
is_consumption: 是否为消费
|
||||
|
||||
Returns:
|
||||
是否更新成功
|
||||
"""
|
||||
current_version = balance.version
|
||||
new_balance = balance.balance + delta
|
||||
|
||||
# 构建更新语句
|
||||
update_values: dict[str, Any] = {
|
||||
"balance": new_balance,
|
||||
"version": current_version + 1,
|
||||
}
|
||||
|
||||
if is_recharge and delta > 0:
|
||||
update_values["total_recharged"] = balance.total_recharged + delta
|
||||
if is_consumption and delta < 0:
|
||||
update_values["total_consumed"] = balance.total_consumed + abs(delta)
|
||||
|
||||
stmt = (
|
||||
update(UserBalance)
|
||||
.where(
|
||||
and_(
|
||||
UserBalance.id == balance.id,
|
||||
UserBalance.version == current_version, # 乐观锁检查
|
||||
)
|
||||
)
|
||||
.values(**update_values)
|
||||
)
|
||||
|
||||
result = await self.session.execute(stmt)
|
||||
|
||||
if result.rowcount == 1:
|
||||
# 更新成功,刷新对象
|
||||
balance.balance = new_balance
|
||||
balance.version = current_version + 1
|
||||
if is_recharge and delta > 0:
|
||||
balance.total_recharged += delta
|
||||
if is_consumption and delta < 0:
|
||||
balance.total_consumed += abs(delta)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def freeze_balance(
|
||||
self,
|
||||
balance: UserBalance,
|
||||
amount: int,
|
||||
) -> bool:
|
||||
"""
|
||||
冻结余额
|
||||
|
||||
Args:
|
||||
balance: 余额账户
|
||||
amount: 冻结金额(正数)
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
if amount <= 0:
|
||||
return False
|
||||
if balance.available_balance < amount:
|
||||
return False
|
||||
|
||||
current_version = balance.version
|
||||
stmt = (
|
||||
update(UserBalance)
|
||||
.where(
|
||||
and_(
|
||||
UserBalance.id == balance.id,
|
||||
UserBalance.version == current_version,
|
||||
UserBalance.balance - UserBalance.frozen_balance >= amount,
|
||||
)
|
||||
)
|
||||
.values(
|
||||
frozen_balance=UserBalance.frozen_balance + amount,
|
||||
version=current_version + 1,
|
||||
)
|
||||
)
|
||||
|
||||
result = await self.session.execute(stmt)
|
||||
|
||||
if result.rowcount == 1:
|
||||
balance.frozen_balance += amount
|
||||
balance.version = current_version + 1
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def unfreeze_balance(
|
||||
self,
|
||||
balance: UserBalance,
|
||||
amount: int,
|
||||
) -> bool:
|
||||
"""
|
||||
解冻余额
|
||||
|
||||
Args:
|
||||
balance: 余额账户
|
||||
amount: 解冻金额(正数)
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
if amount <= 0:
|
||||
return False
|
||||
if balance.frozen_balance < amount:
|
||||
return False
|
||||
|
||||
current_version = balance.version
|
||||
stmt = (
|
||||
update(UserBalance)
|
||||
.where(
|
||||
and_(
|
||||
UserBalance.id == balance.id,
|
||||
UserBalance.version == current_version,
|
||||
UserBalance.frozen_balance >= amount,
|
||||
)
|
||||
)
|
||||
.values(
|
||||
frozen_balance=UserBalance.frozen_balance - amount,
|
||||
version=current_version + 1,
|
||||
)
|
||||
)
|
||||
|
||||
result = await self.session.execute(stmt)
|
||||
|
||||
if result.rowcount == 1:
|
||||
balance.frozen_balance -= amount
|
||||
balance.version = current_version + 1
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class TransactionRepository(BaseRepository[BalanceTransaction]):
|
||||
"""交易记录仓库"""
|
||||
|
||||
model = BalanceTransaction
|
||||
|
||||
async def get_by_user_id(
|
||||
self,
|
||||
user_id: str,
|
||||
*,
|
||||
offset: int = 0,
|
||||
limit: int = 20,
|
||||
transaction_type: TransactionType | None = None,
|
||||
status: TransactionStatus | None = None,
|
||||
) -> list[BalanceTransaction]:
|
||||
"""
|
||||
获取用户的交易记录
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
offset: 偏移量
|
||||
limit: 限制数量
|
||||
transaction_type: 交易类型过滤
|
||||
status: 状态过滤
|
||||
|
||||
Returns:
|
||||
交易记录列表
|
||||
"""
|
||||
stmt = select(BalanceTransaction).where(
|
||||
BalanceTransaction.user_id == user_id
|
||||
)
|
||||
|
||||
if transaction_type:
|
||||
stmt = stmt.where(BalanceTransaction.transaction_type == transaction_type)
|
||||
if status:
|
||||
stmt = stmt.where(BalanceTransaction.status == status)
|
||||
|
||||
stmt = (
|
||||
stmt
|
||||
.order_by(BalanceTransaction.created_at.desc())
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
result = await self.session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def count_by_user_id(
|
||||
self,
|
||||
user_id: str,
|
||||
*,
|
||||
transaction_type: TransactionType | None = None,
|
||||
status: TransactionStatus | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
统计用户的交易记录数量
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
transaction_type: 交易类型过滤
|
||||
status: 状态过滤
|
||||
|
||||
Returns:
|
||||
记录数量
|
||||
"""
|
||||
from sqlalchemy import func
|
||||
|
||||
stmt = select(func.count()).select_from(BalanceTransaction).where(
|
||||
BalanceTransaction.user_id == user_id
|
||||
)
|
||||
|
||||
if transaction_type:
|
||||
stmt = stmt.where(BalanceTransaction.transaction_type == transaction_type)
|
||||
if status:
|
||||
stmt = stmt.where(BalanceTransaction.status == status)
|
||||
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar() or 0
|
||||
|
||||
async def get_by_idempotency_key(
|
||||
self,
|
||||
idempotency_key: str,
|
||||
) -> BalanceTransaction | None:
|
||||
"""
|
||||
通过幂等键获取交易记录
|
||||
|
||||
用于防止重复提交。
|
||||
|
||||
Args:
|
||||
idempotency_key: 幂等键
|
||||
|
||||
Returns:
|
||||
交易记录或 None
|
||||
"""
|
||||
stmt = select(BalanceTransaction).where(
|
||||
BalanceTransaction.idempotency_key == idempotency_key
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_reference(
|
||||
self,
|
||||
reference_type: str,
|
||||
reference_id: str,
|
||||
) -> list[BalanceTransaction]:
|
||||
"""
|
||||
通过业务关联获取交易记录
|
||||
|
||||
Args:
|
||||
reference_type: 关联业务类型
|
||||
reference_id: 关联业务 ID
|
||||
|
||||
Returns:
|
||||
交易记录列表
|
||||
"""
|
||||
stmt = (
|
||||
select(BalanceTransaction)
|
||||
.where(
|
||||
and_(
|
||||
BalanceTransaction.reference_type == reference_type,
|
||||
BalanceTransaction.reference_id == reference_id,
|
||||
)
|
||||
)
|
||||
.order_by(BalanceTransaction.created_at.desc())
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
138
app/repositories/base.py
Normal file
138
app/repositories/base.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""
|
||||
基础仓库类
|
||||
|
||||
提供通用的 CRUD 操作封装。
|
||||
"""
|
||||
|
||||
from typing import Any, Generic, TypeVar
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import Base
|
||||
|
||||
ModelT = TypeVar("ModelT", bound=Base)
|
||||
|
||||
|
||||
class BaseRepository(Generic[ModelT]):
|
||||
"""
|
||||
基础仓库类
|
||||
|
||||
提供通用的数据库操作方法。
|
||||
"""
|
||||
|
||||
model: type[ModelT]
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
"""
|
||||
初始化仓库
|
||||
|
||||
Args:
|
||||
session: 异步数据库会话
|
||||
"""
|
||||
self.session = session
|
||||
|
||||
async def get_by_id(self, id: str) -> ModelT | None:
|
||||
"""
|
||||
通过 ID 获取实体
|
||||
|
||||
Args:
|
||||
id: 实体 ID
|
||||
|
||||
Returns:
|
||||
实体对象或 None
|
||||
"""
|
||||
return await self.session.get(self.model, id)
|
||||
|
||||
async def get_all(
|
||||
self,
|
||||
*,
|
||||
offset: int = 0,
|
||||
limit: int = 100,
|
||||
) -> list[ModelT]:
|
||||
"""
|
||||
获取所有实体
|
||||
|
||||
Args:
|
||||
offset: 偏移量
|
||||
limit: 限制数量
|
||||
|
||||
Returns:
|
||||
实体列表
|
||||
"""
|
||||
stmt = select(self.model).offset(offset).limit(limit)
|
||||
result = await self.session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def count(self) -> int:
|
||||
"""
|
||||
获取实体总数
|
||||
|
||||
Returns:
|
||||
实体数量
|
||||
"""
|
||||
stmt = select(func.count()).select_from(self.model)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar() or 0
|
||||
|
||||
async def create(self, **kwargs: Any) -> ModelT:
|
||||
"""
|
||||
创建新实体
|
||||
|
||||
Args:
|
||||
**kwargs: 实体属性
|
||||
|
||||
Returns:
|
||||
新创建的实体
|
||||
"""
|
||||
if "id" not in kwargs:
|
||||
kwargs["id"] = str(uuid4())
|
||||
|
||||
entity = self.model(**kwargs)
|
||||
self.session.add(entity)
|
||||
await self.session.flush()
|
||||
await self.session.refresh(entity)
|
||||
return entity
|
||||
|
||||
async def update(
|
||||
self,
|
||||
entity: ModelT,
|
||||
**kwargs: Any,
|
||||
) -> ModelT:
|
||||
"""
|
||||
更新实体
|
||||
|
||||
Args:
|
||||
entity: 要更新的实体
|
||||
**kwargs: 要更新的属性
|
||||
|
||||
Returns:
|
||||
更新后的实体
|
||||
"""
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(entity, key):
|
||||
setattr(entity, key, value)
|
||||
|
||||
await self.session.flush()
|
||||
await self.session.refresh(entity)
|
||||
return entity
|
||||
|
||||
async def delete(self, entity: ModelT) -> None:
|
||||
"""
|
||||
删除实体
|
||||
|
||||
Args:
|
||||
entity: 要删除的实体
|
||||
"""
|
||||
await self.session.delete(entity)
|
||||
await self.session.flush()
|
||||
|
||||
async def commit(self) -> None:
|
||||
"""提交事务"""
|
||||
await self.session.commit()
|
||||
|
||||
async def rollback(self) -> None:
|
||||
"""回滚事务"""
|
||||
await self.session.rollback()
|
||||
|
||||
462
app/repositories/redeem_code.py
Normal file
462
app/repositories/redeem_code.py
Normal file
@@ -0,0 +1,462 @@
|
||||
"""
|
||||
兑换码仓库
|
||||
|
||||
处理兑换码相关的数据库操作。
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select, update, and_, or_, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.redeem_code import (
|
||||
RedeemCode,
|
||||
RedeemCodeBatch,
|
||||
RedeemCodeUsageLog,
|
||||
RedeemCodeStatus,
|
||||
)
|
||||
from app.repositories.base import BaseRepository
|
||||
|
||||
|
||||
class RedeemCodeRepository(BaseRepository[RedeemCode]):
|
||||
"""兑换码仓库"""
|
||||
|
||||
model = RedeemCode
|
||||
|
||||
async def get_by_code(self, code: str) -> RedeemCode | None:
|
||||
"""
|
||||
通过兑换码获取记录
|
||||
|
||||
Args:
|
||||
code: 兑换码
|
||||
|
||||
Returns:
|
||||
兑换码记录或 None
|
||||
"""
|
||||
# 标准化兑换码格式
|
||||
normalized_code = code.strip().upper().replace(" ", "")
|
||||
stmt = select(RedeemCode).where(RedeemCode.code == normalized_code)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_code_for_update(self, code: str) -> RedeemCode | None:
|
||||
"""
|
||||
通过兑换码获取记录(加行级锁)
|
||||
|
||||
用于兑换操作,防止并发兑换。
|
||||
|
||||
Args:
|
||||
code: 兑换码
|
||||
|
||||
Returns:
|
||||
兑换码记录或 None
|
||||
"""
|
||||
normalized_code = code.strip().upper().replace(" ", "")
|
||||
stmt = (
|
||||
select(RedeemCode)
|
||||
.where(RedeemCode.code == normalized_code)
|
||||
.with_for_update()
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_valid_code(self, code: str) -> RedeemCode | None:
|
||||
"""
|
||||
获取有效的兑换码
|
||||
|
||||
检查状态、使用次数和有效期。
|
||||
|
||||
Args:
|
||||
code: 兑换码
|
||||
|
||||
Returns:
|
||||
有效的兑换码或 None
|
||||
"""
|
||||
normalized_code = code.strip().upper().replace(" ", "")
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
stmt = (
|
||||
select(RedeemCode)
|
||||
.where(
|
||||
and_(
|
||||
RedeemCode.code == normalized_code,
|
||||
RedeemCode.status == RedeemCodeStatus.ACTIVE,
|
||||
RedeemCode.used_count < RedeemCode.max_uses,
|
||||
or_(
|
||||
RedeemCode.expires_at.is_(None),
|
||||
RedeemCode.expires_at > now,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def mark_as_used(
|
||||
self,
|
||||
code: RedeemCode,
|
||||
user_id: str,
|
||||
) -> bool:
|
||||
"""
|
||||
标记兑换码已使用
|
||||
|
||||
更新使用计数和状态。
|
||||
|
||||
Args:
|
||||
code: 兑换码记录
|
||||
user_id: 使用者 ID
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
new_used_count = code.used_count + 1
|
||||
new_status = (
|
||||
RedeemCodeStatus.USED
|
||||
if new_used_count >= code.max_uses
|
||||
else RedeemCodeStatus.ACTIVE
|
||||
)
|
||||
|
||||
stmt = (
|
||||
update(RedeemCode)
|
||||
.where(
|
||||
and_(
|
||||
RedeemCode.id == code.id,
|
||||
RedeemCode.used_count == code.used_count, # 乐观锁
|
||||
)
|
||||
)
|
||||
.values(
|
||||
used_count=new_used_count,
|
||||
status=new_status,
|
||||
used_by=user_id,
|
||||
used_at=now,
|
||||
)
|
||||
)
|
||||
|
||||
result = await self.session.execute(stmt)
|
||||
|
||||
if result.rowcount == 1:
|
||||
code.used_count = new_used_count
|
||||
code.status = new_status
|
||||
code.used_by = user_id
|
||||
code.used_at = now
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def get_all_with_filters(
|
||||
self,
|
||||
*,
|
||||
offset: int = 0,
|
||||
limit: int = 20,
|
||||
status: RedeemCodeStatus | None = None,
|
||||
batch_id: str | None = None,
|
||||
code_like: str | None = None,
|
||||
created_after: datetime | None = None,
|
||||
created_before: datetime | None = None,
|
||||
) -> list[RedeemCode]:
|
||||
"""
|
||||
获取兑换码列表(支持过滤)
|
||||
|
||||
Args:
|
||||
offset: 偏移量
|
||||
limit: 限制数量
|
||||
status: 状态过滤
|
||||
batch_id: 批次 ID 过滤
|
||||
code_like: 兑换码模糊匹配
|
||||
created_after: 创建时间起始
|
||||
created_before: 创建时间结束
|
||||
|
||||
Returns:
|
||||
兑换码列表
|
||||
"""
|
||||
stmt = select(RedeemCode)
|
||||
|
||||
conditions = []
|
||||
if status:
|
||||
conditions.append(RedeemCode.status == status)
|
||||
if batch_id:
|
||||
conditions.append(RedeemCode.batch_id == batch_id)
|
||||
if code_like:
|
||||
conditions.append(RedeemCode.code.contains(code_like.upper()))
|
||||
if created_after:
|
||||
conditions.append(RedeemCode.created_at >= created_after)
|
||||
if created_before:
|
||||
conditions.append(RedeemCode.created_at <= created_before)
|
||||
|
||||
if conditions:
|
||||
stmt = stmt.where(and_(*conditions))
|
||||
|
||||
stmt = (
|
||||
stmt
|
||||
.order_by(RedeemCode.created_at.desc())
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
result = await self.session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def count_with_filters(
|
||||
self,
|
||||
*,
|
||||
status: RedeemCodeStatus | None = None,
|
||||
batch_id: str | None = None,
|
||||
code_like: str | None = None,
|
||||
created_after: datetime | None = None,
|
||||
created_before: datetime | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
统计兑换码数量(支持过滤)
|
||||
"""
|
||||
stmt = select(func.count()).select_from(RedeemCode)
|
||||
|
||||
conditions = []
|
||||
if status:
|
||||
conditions.append(RedeemCode.status == status)
|
||||
if batch_id:
|
||||
conditions.append(RedeemCode.batch_id == batch_id)
|
||||
if code_like:
|
||||
conditions.append(RedeemCode.code.contains(code_like.upper()))
|
||||
if created_after:
|
||||
conditions.append(RedeemCode.created_at >= created_after)
|
||||
if created_before:
|
||||
conditions.append(RedeemCode.created_at <= created_before)
|
||||
|
||||
if conditions:
|
||||
stmt = stmt.where(and_(*conditions))
|
||||
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar() or 0
|
||||
|
||||
async def bulk_create(
|
||||
self,
|
||||
codes_data: list[dict[str, Any]],
|
||||
) -> list[RedeemCode]:
|
||||
"""
|
||||
批量创建兑换码
|
||||
|
||||
Args:
|
||||
codes_data: 兑换码数据列表
|
||||
|
||||
Returns:
|
||||
创建的兑换码列表
|
||||
"""
|
||||
codes = [RedeemCode(**data) for data in codes_data]
|
||||
self.session.add_all(codes)
|
||||
await self.session.flush()
|
||||
return codes
|
||||
|
||||
async def disable_code(self, code: RedeemCode) -> RedeemCode:
|
||||
"""
|
||||
禁用兑换码
|
||||
|
||||
Args:
|
||||
code: 兑换码记录
|
||||
|
||||
Returns:
|
||||
更新后的兑换码
|
||||
"""
|
||||
return await self.update(code, status=RedeemCodeStatus.DISABLED)
|
||||
|
||||
async def enable_code(self, code: RedeemCode) -> RedeemCode:
|
||||
"""
|
||||
启用兑换码
|
||||
|
||||
Args:
|
||||
code: 兑换码记录
|
||||
|
||||
Returns:
|
||||
更新后的兑换码
|
||||
"""
|
||||
# 只有禁用状态的可以重新启用
|
||||
if code.status != RedeemCodeStatus.DISABLED:
|
||||
return code
|
||||
|
||||
# 如果使用次数已满,改为已使用状态
|
||||
if code.used_count >= code.max_uses:
|
||||
return await self.update(code, status=RedeemCodeStatus.USED)
|
||||
|
||||
return await self.update(code, status=RedeemCodeStatus.ACTIVE)
|
||||
|
||||
|
||||
class RedeemCodeBatchRepository(BaseRepository[RedeemCodeBatch]):
|
||||
"""兑换码批次仓库"""
|
||||
|
||||
model = RedeemCodeBatch
|
||||
|
||||
async def get_all_batches(
|
||||
self,
|
||||
*,
|
||||
offset: int = 0,
|
||||
limit: int = 20,
|
||||
) -> list[RedeemCodeBatch]:
|
||||
"""
|
||||
获取所有批次
|
||||
|
||||
Args:
|
||||
offset: 偏移量
|
||||
limit: 限制数量
|
||||
|
||||
Returns:
|
||||
批次列表
|
||||
"""
|
||||
stmt = (
|
||||
select(RedeemCodeBatch)
|
||||
.order_by(RedeemCodeBatch.created_at.desc())
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def increment_used_count(self, batch_id: str) -> None:
|
||||
"""
|
||||
增加批次已使用计数
|
||||
|
||||
Args:
|
||||
batch_id: 批次 ID
|
||||
"""
|
||||
stmt = (
|
||||
update(RedeemCodeBatch)
|
||||
.where(RedeemCodeBatch.id == batch_id)
|
||||
.values(used_count=RedeemCodeBatch.used_count + 1)
|
||||
)
|
||||
await self.session.execute(stmt)
|
||||
|
||||
|
||||
class RedeemCodeUsageLogRepository(BaseRepository[RedeemCodeUsageLog]):
|
||||
"""兑换码使用日志仓库"""
|
||||
|
||||
model = RedeemCodeUsageLog
|
||||
|
||||
async def get_by_code_id(
|
||||
self,
|
||||
redeem_code_id: str,
|
||||
*,
|
||||
offset: int = 0,
|
||||
limit: int = 20,
|
||||
) -> list[RedeemCodeUsageLog]:
|
||||
"""
|
||||
获取兑换码的使用日志
|
||||
|
||||
Args:
|
||||
redeem_code_id: 兑换码 ID
|
||||
offset: 偏移量
|
||||
limit: 限制数量
|
||||
|
||||
Returns:
|
||||
使用日志列表
|
||||
"""
|
||||
stmt = (
|
||||
select(RedeemCodeUsageLog)
|
||||
.where(RedeemCodeUsageLog.redeem_code_id == redeem_code_id)
|
||||
.order_by(RedeemCodeUsageLog.created_at.desc())
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_by_user_id(
|
||||
self,
|
||||
user_id: str,
|
||||
*,
|
||||
offset: int = 0,
|
||||
limit: int = 20,
|
||||
) -> list[RedeemCodeUsageLog]:
|
||||
"""
|
||||
获取用户的兑换日志
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
offset: 偏移量
|
||||
limit: 限制数量
|
||||
|
||||
Returns:
|
||||
使用日志列表
|
||||
"""
|
||||
stmt = (
|
||||
select(RedeemCodeUsageLog)
|
||||
.where(RedeemCodeUsageLog.user_id == user_id)
|
||||
.order_by(RedeemCodeUsageLog.created_at.desc())
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_all_with_filters(
|
||||
self,
|
||||
*,
|
||||
offset: int = 0,
|
||||
limit: int = 20,
|
||||
redeem_code_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
code_like: str | None = None,
|
||||
created_after: datetime | None = None,
|
||||
created_before: datetime | None = None,
|
||||
) -> list[RedeemCodeUsageLog]:
|
||||
"""
|
||||
获取使用日志(支持过滤)
|
||||
"""
|
||||
stmt = select(RedeemCodeUsageLog)
|
||||
|
||||
conditions = []
|
||||
if redeem_code_id:
|
||||
conditions.append(RedeemCodeUsageLog.redeem_code_id == redeem_code_id)
|
||||
if user_id:
|
||||
conditions.append(RedeemCodeUsageLog.user_id == user_id)
|
||||
if code_like:
|
||||
conditions.append(RedeemCodeUsageLog.code_snapshot.contains(code_like.upper()))
|
||||
if created_after:
|
||||
conditions.append(RedeemCodeUsageLog.created_at >= created_after)
|
||||
if created_before:
|
||||
conditions.append(RedeemCodeUsageLog.created_at <= created_before)
|
||||
|
||||
if conditions:
|
||||
stmt = stmt.where(and_(*conditions))
|
||||
|
||||
stmt = (
|
||||
stmt
|
||||
.order_by(RedeemCodeUsageLog.created_at.desc())
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
result = await self.session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def count_with_filters(
|
||||
self,
|
||||
*,
|
||||
redeem_code_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
code_like: str | None = None,
|
||||
created_after: datetime | None = None,
|
||||
created_before: datetime | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
统计使用日志数量
|
||||
"""
|
||||
stmt = select(func.count()).select_from(RedeemCodeUsageLog)
|
||||
|
||||
conditions = []
|
||||
if redeem_code_id:
|
||||
conditions.append(RedeemCodeUsageLog.redeem_code_id == redeem_code_id)
|
||||
if user_id:
|
||||
conditions.append(RedeemCodeUsageLog.user_id == user_id)
|
||||
if code_like:
|
||||
conditions.append(RedeemCodeUsageLog.code_snapshot.contains(code_like.upper()))
|
||||
if created_after:
|
||||
conditions.append(RedeemCodeUsageLog.created_at >= created_after)
|
||||
if created_before:
|
||||
conditions.append(RedeemCodeUsageLog.created_at <= created_before)
|
||||
|
||||
if conditions:
|
||||
stmt = stmt.where(and_(*conditions))
|
||||
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar() or 0
|
||||
|
||||
141
app/repositories/user.py
Normal file
141
app/repositories/user.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
用户仓库
|
||||
|
||||
处理用户相关的数据库操作。
|
||||
"""
|
||||
|
||||
from sqlalchemy import or_, select
|
||||
|
||||
from app.models.user import User
|
||||
from app.repositories.base import BaseRepository
|
||||
|
||||
|
||||
class UserRepository(BaseRepository[User]):
|
||||
"""用户数据仓库"""
|
||||
|
||||
model = User
|
||||
|
||||
async def get_by_username(self, username: str) -> User | None:
|
||||
"""
|
||||
通过用户名获取用户
|
||||
|
||||
Args:
|
||||
username: 用户名
|
||||
|
||||
Returns:
|
||||
用户对象或 None
|
||||
"""
|
||||
stmt = select(User).where(User.username == username.lower())
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_email(self, email: str) -> User | None:
|
||||
"""
|
||||
通过邮箱获取用户
|
||||
|
||||
Args:
|
||||
email: 邮箱地址
|
||||
|
||||
Returns:
|
||||
用户对象或 None
|
||||
"""
|
||||
stmt = select(User).where(User.email == email.lower())
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_username_or_email(self, identifier: str) -> User | None:
|
||||
"""
|
||||
通过用户名或邮箱获取用户
|
||||
|
||||
Args:
|
||||
identifier: 用户名或邮箱
|
||||
|
||||
Returns:
|
||||
用户对象或 None
|
||||
"""
|
||||
identifier_lower = identifier.lower()
|
||||
stmt = select(User).where(
|
||||
or_(
|
||||
User.username == identifier_lower,
|
||||
User.email == identifier_lower,
|
||||
)
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def exists_by_username(self, username: str) -> bool:
|
||||
"""
|
||||
检查用户名是否存在
|
||||
|
||||
Args:
|
||||
username: 用户名
|
||||
|
||||
Returns:
|
||||
是否存在
|
||||
"""
|
||||
user = await self.get_by_username(username)
|
||||
return user is not None
|
||||
|
||||
async def exists_by_email(self, email: str) -> bool:
|
||||
"""
|
||||
检查邮箱是否存在
|
||||
|
||||
Args:
|
||||
email: 邮箱地址
|
||||
|
||||
Returns:
|
||||
是否存在
|
||||
"""
|
||||
if not email:
|
||||
return False
|
||||
user = await self.get_by_email(email)
|
||||
return user is not None
|
||||
|
||||
async def get_by_oauth(
|
||||
self,
|
||||
provider: str,
|
||||
oauth_user_id: str,
|
||||
) -> User | None:
|
||||
"""
|
||||
通过 OAuth2 提供商和用户 ID 获取用户
|
||||
|
||||
Args:
|
||||
provider: OAuth2 提供商标识
|
||||
oauth_user_id: OAuth2 用户 ID
|
||||
|
||||
Returns:
|
||||
用户对象或 None
|
||||
"""
|
||||
stmt = select(User).where(
|
||||
User.oauth_provider == provider,
|
||||
User.oauth_user_id == oauth_user_id,
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_active_users(
|
||||
self,
|
||||
*,
|
||||
offset: int = 0,
|
||||
limit: int = 100,
|
||||
) -> list[User]:
|
||||
"""
|
||||
获取活跃用户列表
|
||||
|
||||
Args:
|
||||
offset: 偏移量
|
||||
limit: 限制数量
|
||||
|
||||
Returns:
|
||||
活跃用户列表
|
||||
"""
|
||||
stmt = (
|
||||
select(User)
|
||||
.where(User.is_active == True) # noqa: E712
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.order_by(User.created_at.desc())
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
Reference in New Issue
Block a user