Files
SatoNano/app/services/auth.py
2026-01-06 23:49:23 +08:00

297 lines
8.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
认证服务
处理用户认证相关的业务逻辑。
"""
from datetime import timedelta
import jwt
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import settings
from app.core.exceptions import (
InvalidCredentialsError,
PasswordValidationError,
TokenError,
TokenExpiredError,
UserDisabledError,
UserNotFoundError,
)
from app.core.security import (
create_access_token,
create_refresh_token,
decode_token,
hash_password,
password_needs_rehash,
verify_password,
)
from app.models.user import User
from app.repositories.user import UserRepository
from app.schemas.auth import PasswordChangeRequest, TokenResponse
from app.services.user import UserService
class AuthService:
    """Authentication service.

    Implements the authentication business logic: verifying credentials,
    issuing access/refresh token pairs, refreshing tokens, resolving the
    current user from an access token, and changing passwords.
    """

    # Lazily computed hash of a throwaway password (see
    # ``_get_dummy_password_hash``). It is produced with ``hash_password`` so
    # it uses the same algorithm and cost parameters as real user hashes.
    _dummy_password_hash = None

    def __init__(self, session: AsyncSession):
        """
        Initialize the authentication service.

        Args:
            session: Database session; it is also handed to the user
                repository and the user service created here.
        """
        self.session = session
        self.user_repo = UserRepository(session)
        self.user_service = UserService(session)

    @classmethod
    def _get_dummy_password_hash(cls) -> str:
        """Return a valid password hash used only for timing equalization.

        Computed on first use with ``hash_password`` and cached on the class.
        """
        if cls._dummy_password_hash is None:
            cls._dummy_password_hash = hash_password(
                "auth-service-timing-equalization-dummy"
            )
        return cls._dummy_password_hash

    async def authenticate(
        self,
        username: str,
        password: str,
    ) -> User:
        """
        Verify a user's credentials.

        If ``password_needs_rehash`` reports that the stored hash is outdated,
        the hash is regenerated from the supplied password and committed.

        Args:
            username: Username or email address.
            password: Plain-text password.

        Returns:
            The authenticated user.

        Raises:
            InvalidCredentialsError: Unknown user or wrong password.
            UserDisabledError: Credentials are valid but the user is inactive.
        """
        # Look the user up by username or email.
        user = await self.user_repo.get_by_username_or_email(username)
        if not user:
            # Timing-attack defence: still perform a full verification against
            # a *real* hash so a missing user is not detectably faster than a
            # wrong password. A hard-coded placeholder hash is likely rejected
            # by the verifier before any expensive hashing happens.
            verify_password(password, self._get_dummy_password_hash())
            raise InvalidCredentialsError()

        if not verify_password(password, user.hashed_password):
            raise InvalidCredentialsError()

        # Checked after the password so account status is not revealed to
        # callers who do not know the password.
        if not user.is_active:
            raise UserDisabledError()

        # Upgrade the stored hash when hashing parameters have changed.
        if password_needs_rehash(user.hashed_password):
            await self.user_repo.update(
                user,
                hashed_password=hash_password(password),
            )
            await self.user_repo.commit()

        return user

    async def login(
        self,
        username: str,
        password: str,
    ) -> tuple[User, TokenResponse]:
        """
        Log a user in.

        Authenticates the credentials, records the login time and issues a
        new token pair.

        Args:
            username: Username or email address.
            password: Plain-text password.

        Returns:
            A ``(user, token_response)`` tuple.

        Raises:
            InvalidCredentialsError: Unknown user or wrong password.
            UserDisabledError: The user is inactive.
        """
        user = await self.authenticate(username, password)

        # Record the last-login timestamp.
        await self.user_service.update_last_login(user)

        tokens = self._create_tokens(user)
        return user, tokens

    def _create_tokens(self, user: User) -> TokenResponse:
        """
        Create an access token and a refresh token for a user.

        The access token carries ``username`` and ``is_superuser`` as extra
        claims; the refresh token carries only the subject.

        Args:
            user: The user to issue tokens for.

        Returns:
            Token response; ``expires_in`` is the access-token lifetime in
            seconds (``access_token_expire_minutes * 60``).
        """
        access_token = create_access_token(
            subject=user.id,
            extra_claims={
                "username": user.username,
                "is_superuser": user.is_superuser,
            },
        )
        refresh_token = create_refresh_token(subject=user.id)
        return TokenResponse(
            access_token=access_token,
            refresh_token=refresh_token,
            token_type="Bearer",
            expires_in=settings.access_token_expire_minutes * 60,
        )

    @staticmethod
    def _decode_token_of_type(token: str, expected_type: str):
        """
        Decode a JWT and check that its ``type`` claim matches.

        Args:
            token: Encoded JWT.
            expected_type: Required value of the ``type`` claim
                (``"access"`` or ``"refresh"``).

        Returns:
            The payload returned by ``decode_token``.

        Raises:
            TokenExpiredError: The token has expired.
            TokenError: The token is invalid or of the wrong type.
        """
        try:
            payload = decode_token(token)
        except jwt.ExpiredSignatureError as exc:
            # Must precede InvalidTokenError, which ExpiredSignatureError subclasses.
            raise TokenExpiredError() from exc
        except jwt.InvalidTokenError as exc:
            raise TokenError() from exc

        if payload.get("type") != expected_type:
            raise TokenError("无效的令牌类型")
        return payload

    async def _get_active_user(self, payload) -> User:
        """
        Load the active user referenced by a token payload's ``sub`` claim.

        Args:
            payload: Decoded token payload.

        Returns:
            The referenced user.

        Raises:
            TokenError: The payload has no (or an empty) ``sub`` claim.
            UserNotFoundError: No user exists with that ID.
            UserDisabledError: The user is inactive.
        """
        user_id = payload.get("sub")
        if not user_id:
            raise TokenError()

        user = await self.user_repo.get_by_id(user_id)
        if not user:
            raise UserNotFoundError(user_id)
        if not user.is_active:
            raise UserDisabledError()
        return user

    async def refresh_tokens(self, refresh_token: str) -> TokenResponse:
        """
        Issue a new token pair from a refresh token.

        Args:
            refresh_token: Encoded refresh token.

        Returns:
            A new token response with fresh access and refresh tokens.

        Raises:
            TokenExpiredError: The token has expired.
            TokenError: The token is invalid, is not a refresh token, or has
                no subject.
            UserNotFoundError: The token's subject no longer exists.
            UserDisabledError: The user is inactive.
        """
        payload = self._decode_token_of_type(refresh_token, "refresh")
        user = await self._get_active_user(payload)
        return self._create_tokens(user)

    async def change_password(
        self,
        user_id: str,
        password_data: PasswordChangeRequest,
    ) -> None:
        """
        Change a user's password.

        The current password must verify and the new password must satisfy
        the configured strength policy before the new hash is committed.

        Args:
            user_id: ID of the user.
            password_data: Current and new password.

        Raises:
            UserNotFoundError: The user does not exist.
            InvalidCredentialsError: The current password is wrong.
            PasswordValidationError: The new password fails the policy.
        """
        user = await self.user_repo.get_by_id(user_id)
        if not user:
            raise UserNotFoundError(user_id)

        if not verify_password(password_data.current_password, user.hashed_password):
            raise InvalidCredentialsError("当前密码错误")

        self._validate_password_strength(password_data.new_password)

        await self.user_repo.update(
            user,
            hashed_password=hash_password(password_data.new_password),
        )
        await self.user_repo.commit()

    def _validate_password_strength(self, password: str) -> None:
        """
        Check a password against the configured strength policy.

        Every violated rule is collected so the caller can report them all
        at once.

        Args:
            password: Candidate plain-text password.

        Raises:
            PasswordValidationError: One or more rules are violated; the
                message lists every violation, separated by ``"; "``.
        """
        import re
        import string

        errors: list[str] = []
        if len(password) < settings.password_min_length:
            errors.append(f"密码长度不能少于 {settings.password_min_length}")
        if len(password) > settings.password_max_length:
            errors.append(f"密码长度不能超过 {settings.password_max_length}")
        if settings.password_require_uppercase and not re.search(r"[A-Z]", password):
            errors.append("至少包含一个大写字母")
        if settings.password_require_lowercase and not re.search(r"[a-z]", password):
            errors.append("至少包含一个小写字母")
        if settings.password_require_digit and not re.search(r"\d", password):
            errors.append("至少包含一个数字")
        # Any ASCII punctuation (string.punctuation) counts as a special
        # character, so common symbols such as - _ + = [ ] / ; ' ~ qualify.
        if settings.password_require_special and not any(
            ch in string.punctuation for ch in password
        ):
            errors.append("至少包含一个特殊字符")
        if errors:
            # Separate the messages so multiple violations stay readable.
            raise PasswordValidationError("; ".join(errors))

    async def get_current_user(self, token: str) -> User:
        """
        Resolve the current user from an access token.

        Args:
            token: Encoded access token.

        Returns:
            The authenticated, active user.

        Raises:
            TokenExpiredError: The token has expired.
            TokenError: The token is invalid, is not an access token, or has
                no subject.
            UserNotFoundError: The user does not exist.
            UserDisabledError: The user is inactive.
        """
        payload = self._decode_token_of_type(token, "access")
        return await self._get_active_user(payload)