142 lines
3.5 KiB
Python
142 lines
3.5 KiB
Python
"""
User repository.

Handles user-related database operations.
"""
|
|
|
|
from sqlalchemy import or_, select
|
|
|
|
from app.models.user import User
|
|
from app.repositories.base import BaseRepository
|
|
|
|
|
|
class UserRepository(BaseRepository[User]):
    """Repository for user records.

    Groups the user-specific lookups (by username, e-mail, OAuth identity and
    active flag). All queries are executed through ``self.session``, which is
    not defined in this class (presumably supplied by ``BaseRepository`` --
    confirm there).
    """

    # ORM model this repository operates on.
    model = User

    async def get_by_username(self, username: str) -> User | None:
        """Fetch a user by username.

        The argument is lower-cased before the comparison while the stored
        column is compared as-is, so this assumes usernames are persisted in
        lower case (TODO confirm against the write path).

        Args:
            username: Username to look up.

        Returns:
            The matching user, or ``None`` when no row matches.
        """
        stmt = select(User).where(User.username == username.lower())
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_by_email(self, email: str) -> User | None:
        """Fetch a user by e-mail address.

        The argument is lower-cased before the comparison while the stored
        column is compared as-is, so this assumes e-mail addresses are
        persisted in lower case (TODO confirm against the write path).

        Args:
            email: E-mail address to look up.

        Returns:
            The matching user, or ``None`` when no row matches.
        """
        stmt = select(User).where(User.email == email.lower())
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_by_username_or_email(self, identifier: str) -> User | None:
        """Fetch a user whose username or e-mail equals ``identifier``.

        The identifier is lower-cased once and compared against both columns.

        Args:
            identifier: A username or an e-mail address.

        Returns:
            The matching user, or ``None`` when no row matches.
        """
        identifier_lower = identifier.lower()
        # NOTE(review): scalar_one_or_none() raises if more than one row
        # matches, which would happen if one account's username equals
        # another account's e-mail. Confirm that the username validation
        # rules make this impossible.
        stmt = select(User).where(
            or_(
                User.username == identifier_lower,
                User.email == identifier_lower,
            )
        )
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def exists_by_username(self, username: str) -> bool:
        """Check whether a user with the given username exists.

        Args:
            username: Username to check.

        Returns:
            ``True`` if a matching user exists; ``False`` otherwise, including
            when ``username`` is empty or ``None``.
        """
        # Mirror exists_by_email: an empty value cannot match an account, so
        # skip the query (and avoid calling .lower() on None).
        if not username:
            return False
        user = await self.get_by_username(username)
        return user is not None

    async def exists_by_email(self, email: str) -> bool:
        """Check whether a user with the given e-mail address exists.

        Args:
            email: E-mail address to check.

        Returns:
            ``True`` if a matching user exists; ``False`` otherwise, including
            when ``email`` is empty or ``None``.
        """
        # An empty value short-circuits to False without touching the database.
        if not email:
            return False
        user = await self.get_by_email(email)
        return user is not None

    async def get_by_oauth(
        self,
        provider: str,
        oauth_user_id: str,
    ) -> User | None:
        """Fetch a user by OAuth2 provider and provider-side user ID.

        Both values are compared exactly as given (no case normalisation).

        Args:
            provider: OAuth2 provider identifier.
            oauth_user_id: User ID issued by that provider.

        Returns:
            The matching user, or ``None`` when no row matches.
        """
        stmt = select(User).where(
            User.oauth_provider == provider,
            User.oauth_user_id == oauth_user_id,
        )
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_active_users(
        self,
        *,
        offset: int = 0,
        limit: int = 100,
    ) -> list[User]:
        """Return one page of active users, newest first.

        Args:
            offset: Number of rows to skip.
            limit: Maximum number of rows to return.

        Returns:
            Users with ``is_active`` set, ordered by ``created_at`` descending.
        """
        # NOTE(review): ordering uses created_at only; rows sharing the same
        # timestamp may move between pages. Consider a unique tie-breaker.
        stmt = (
            select(User)
            .where(User.is_active == True)  # noqa: E712
            .order_by(User.created_at.desc())
            .offset(offset)
            .limit(limit)
        )
        result = await self.session.execute(stmt)
        return list(result.scalars().all())
|
|
|