""" 用户仓库 处理用户相关的数据库操作。 """ from sqlalchemy import or_, select from app.models.user import User from app.repositories.base import BaseRepository class UserRepository(BaseRepository[User]): """用户数据仓库""" model = User async def get_by_username(self, username: str) -> User | None: """ 通过用户名获取用户 Args: username: 用户名 Returns: 用户对象或 None """ stmt = select(User).where(User.username == username.lower()) result = await self.session.execute(stmt) return result.scalar_one_or_none() async def get_by_email(self, email: str) -> User | None: """ 通过邮箱获取用户 Args: email: 邮箱地址 Returns: 用户对象或 None """ stmt = select(User).where(User.email == email.lower()) result = await self.session.execute(stmt) return result.scalar_one_or_none() async def get_by_username_or_email(self, identifier: str) -> User | None: """ 通过用户名或邮箱获取用户 Args: identifier: 用户名或邮箱 Returns: 用户对象或 None """ identifier_lower = identifier.lower() stmt = select(User).where( or_( User.username == identifier_lower, User.email == identifier_lower, ) ) result = await self.session.execute(stmt) return result.scalar_one_or_none() async def exists_by_username(self, username: str) -> bool: """ 检查用户名是否存在 Args: username: 用户名 Returns: 是否存在 """ user = await self.get_by_username(username) return user is not None async def exists_by_email(self, email: str) -> bool: """ 检查邮箱是否存在 Args: email: 邮箱地址 Returns: 是否存在 """ if not email: return False user = await self.get_by_email(email) return user is not None async def get_by_oauth( self, provider: str, oauth_user_id: str, ) -> User | None: """ 通过 OAuth2 提供商和用户 ID 获取用户 Args: provider: OAuth2 提供商标识 oauth_user_id: OAuth2 用户 ID Returns: 用户对象或 None """ stmt = select(User).where( User.oauth_provider == provider, User.oauth_user_id == oauth_user_id, ) result = await self.session.execute(stmt) return result.scalar_one_or_none() async def get_active_users( self, *, offset: int = 0, limit: int = 100, ) -> list[User]: """ 获取活跃用户列表 Args: offset: 偏移量 limit: 限制数量 Returns: 活跃用户列表 """ stmt = ( select(User) .where(User.is_active == True) # noqa: E712 .offset(offset) .limit(limit) .order_by(User.created_at.desc()) ) result = await self.session.execute(stmt) return list(result.scalars().all())