提供基本前后端骨架
This commit is contained in:
296
app/services/auth.py
Normal file
296
app/services/auth.py
Normal file
@@ -0,0 +1,296 @@
|
||||
"""
|
||||
认证服务
|
||||
|
||||
处理用户认证相关的业务逻辑。
|
||||
"""
|
||||
|
||||
from datetime import timedelta
|
||||
|
||||
import jwt
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.exceptions import (
|
||||
InvalidCredentialsError,
|
||||
PasswordValidationError,
|
||||
TokenError,
|
||||
TokenExpiredError,
|
||||
UserDisabledError,
|
||||
UserNotFoundError,
|
||||
)
|
||||
from app.core.security import (
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
decode_token,
|
||||
hash_password,
|
||||
password_needs_rehash,
|
||||
verify_password,
|
||||
)
|
||||
from app.models.user import User
|
||||
from app.repositories.user import UserRepository
|
||||
from app.schemas.auth import PasswordChangeRequest, TokenResponse
|
||||
from app.services.user import UserService
|
||||
|
||||
|
||||
class AuthService:
|
||||
"""认证服务"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
"""
|
||||
初始化认证服务
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
"""
|
||||
self.session = session
|
||||
self.user_repo = UserRepository(session)
|
||||
self.user_service = UserService(session)
|
||||
|
||||
async def authenticate(
|
||||
self,
|
||||
username: str,
|
||||
password: str,
|
||||
) -> User:
|
||||
"""
|
||||
验证用户凭证
|
||||
|
||||
Args:
|
||||
username: 用户名或邮箱
|
||||
password: 密码
|
||||
|
||||
Returns:
|
||||
验证成功的用户对象
|
||||
|
||||
Raises:
|
||||
InvalidCredentialsError: 凭证无效
|
||||
UserDisabledError: 用户被禁用
|
||||
"""
|
||||
# 查找用户(支持用户名或邮箱登录)
|
||||
user = await self.user_repo.get_by_username_or_email(username)
|
||||
|
||||
if not user:
|
||||
# 防止时序攻击:即使用户不存在也进行密码验证
|
||||
verify_password(password, "$argon2id$v=19$m=65536,t=3,p=4$dummy$dummy")
|
||||
raise InvalidCredentialsError()
|
||||
|
||||
# 验证密码
|
||||
if not verify_password(password, user.hashed_password):
|
||||
raise InvalidCredentialsError()
|
||||
|
||||
# 检查用户状态
|
||||
if not user.is_active:
|
||||
raise UserDisabledError()
|
||||
|
||||
# 检查是否需要重新哈希密码(参数升级)
|
||||
if password_needs_rehash(user.hashed_password):
|
||||
await self.user_repo.update(
|
||||
user,
|
||||
hashed_password=hash_password(password),
|
||||
)
|
||||
await self.user_repo.commit()
|
||||
|
||||
return user
|
||||
|
||||
async def login(
|
||||
self,
|
||||
username: str,
|
||||
password: str,
|
||||
) -> tuple[User, TokenResponse]:
|
||||
"""
|
||||
用户登录
|
||||
|
||||
Args:
|
||||
username: 用户名或邮箱
|
||||
password: 密码
|
||||
|
||||
Returns:
|
||||
(用户对象, 令牌响应)
|
||||
"""
|
||||
user = await self.authenticate(username, password)
|
||||
|
||||
# 更新最后登录时间
|
||||
await self.user_service.update_last_login(user)
|
||||
|
||||
# 生成令牌
|
||||
tokens = self._create_tokens(user)
|
||||
|
||||
return user, tokens
|
||||
|
||||
def _create_tokens(self, user: User) -> TokenResponse:
|
||||
"""
|
||||
为用户创建访问令牌和刷新令牌
|
||||
|
||||
Args:
|
||||
user: 用户对象
|
||||
|
||||
Returns:
|
||||
令牌响应
|
||||
"""
|
||||
access_token = create_access_token(
|
||||
subject=user.id,
|
||||
extra_claims={
|
||||
"username": user.username,
|
||||
"is_superuser": user.is_superuser,
|
||||
},
|
||||
)
|
||||
|
||||
refresh_token = create_refresh_token(subject=user.id)
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
token_type="Bearer",
|
||||
expires_in=settings.access_token_expire_minutes * 60,
|
||||
)
|
||||
|
||||
async def refresh_tokens(self, refresh_token: str) -> TokenResponse:
|
||||
"""
|
||||
刷新访问令牌
|
||||
|
||||
Args:
|
||||
refresh_token: 刷新令牌
|
||||
|
||||
Returns:
|
||||
新的令牌响应
|
||||
|
||||
Raises:
|
||||
TokenError: 令牌无效
|
||||
TokenExpiredError: 令牌已过期
|
||||
"""
|
||||
try:
|
||||
payload = decode_token(refresh_token)
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise TokenExpiredError()
|
||||
except jwt.InvalidTokenError:
|
||||
raise TokenError()
|
||||
|
||||
# 验证令牌类型
|
||||
if payload.get("type") != "refresh":
|
||||
raise TokenError("无效的令牌类型")
|
||||
|
||||
# 获取用户
|
||||
user_id = payload.get("sub")
|
||||
if not user_id:
|
||||
raise TokenError()
|
||||
|
||||
user = await self.user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
raise UserNotFoundError(user_id)
|
||||
|
||||
if not user.is_active:
|
||||
raise UserDisabledError()
|
||||
|
||||
# 生成新令牌
|
||||
return self._create_tokens(user)
|
||||
|
||||
async def change_password(
|
||||
self,
|
||||
user_id: str,
|
||||
password_data: PasswordChangeRequest,
|
||||
) -> None:
|
||||
"""
|
||||
修改用户密码
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
password_data: 密码修改数据
|
||||
|
||||
Raises:
|
||||
UserNotFoundError: 用户不存在
|
||||
InvalidCredentialsError: 当前密码错误
|
||||
PasswordValidationError: 新密码不符合要求
|
||||
"""
|
||||
user = await self.user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
raise UserNotFoundError(user_id)
|
||||
|
||||
# 验证当前密码
|
||||
if not verify_password(password_data.current_password, user.hashed_password):
|
||||
raise InvalidCredentialsError("当前密码错误")
|
||||
|
||||
# 验证新密码强度
|
||||
self._validate_password_strength(password_data.new_password)
|
||||
|
||||
# 更新密码
|
||||
await self.user_repo.update(
|
||||
user,
|
||||
hashed_password=hash_password(password_data.new_password),
|
||||
)
|
||||
await self.user_repo.commit()
|
||||
|
||||
def _validate_password_strength(self, password: str) -> None:
|
||||
"""
|
||||
验证密码强度
|
||||
|
||||
Args:
|
||||
password: 密码
|
||||
|
||||
Raises:
|
||||
PasswordValidationError: 密码不符合要求
|
||||
"""
|
||||
import re
|
||||
|
||||
errors: list[str] = []
|
||||
|
||||
if len(password) < settings.password_min_length:
|
||||
errors.append(f"密码长度不能少于 {settings.password_min_length} 位")
|
||||
|
||||
if len(password) > settings.password_max_length:
|
||||
errors.append(f"密码长度不能超过 {settings.password_max_length} 位")
|
||||
|
||||
if settings.password_require_uppercase and not re.search(r"[A-Z]", password):
|
||||
errors.append("至少包含一个大写字母")
|
||||
|
||||
if settings.password_require_lowercase and not re.search(r"[a-z]", password):
|
||||
errors.append("至少包含一个小写字母")
|
||||
|
||||
if settings.password_require_digit and not re.search(r"\d", password):
|
||||
errors.append("至少包含一个数字")
|
||||
|
||||
if settings.password_require_special and not re.search(r"[!@#$%^&*(),.?\":{}|<>]", password):
|
||||
errors.append("至少包含一个特殊字符")
|
||||
|
||||
if errors:
|
||||
raise PasswordValidationError(";".join(errors))
|
||||
|
||||
async def get_current_user(self, token: str) -> User:
|
||||
"""
|
||||
从令牌获取当前用户
|
||||
|
||||
Args:
|
||||
token: 访问令牌
|
||||
|
||||
Returns:
|
||||
用户对象
|
||||
|
||||
Raises:
|
||||
TokenError: 令牌无效
|
||||
TokenExpiredError: 令牌已过期
|
||||
UserNotFoundError: 用户不存在
|
||||
UserDisabledError: 用户被禁用
|
||||
"""
|
||||
try:
|
||||
payload = decode_token(token)
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise TokenExpiredError()
|
||||
except jwt.InvalidTokenError:
|
||||
raise TokenError()
|
||||
|
||||
# 验证令牌类型
|
||||
if payload.get("type") != "access":
|
||||
raise TokenError("无效的令牌类型")
|
||||
|
||||
user_id = payload.get("sub")
|
||||
if not user_id:
|
||||
raise TokenError()
|
||||
|
||||
user = await self.user_repo.get_by_id(user_id)
|
||||
if not user:
|
||||
raise UserNotFoundError(user_id)
|
||||
|
||||
if not user.is_active:
|
||||
raise UserDisabledError()
|
||||
|
||||
return user
|
||||
|
||||
Reference in New Issue
Block a user