Files
SatoNano/app/services/oauth2.py
2026-01-06 23:49:23 +08:00

396 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
OAuth2 认证服务
处理 OAuth2 认证流程,支持主备端点自动切换。
当首选端点不可达时,自动回退到备用端点。
"""
import logging
import secrets
import time
from urllib.parse import urlencode

import httpx
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.config import settings
from app.core.exceptions import (
    AuthenticationError,
    ResourceConflictError,
)
from app.core.security import create_access_token, create_refresh_token
from app.models.user import User
from app.repositories.user import UserRepository
from app.schemas.auth import TokenResponse
from app.schemas.oauth2 import OAuth2TokenData, OAuth2UserInfo
logger: logging.Logger = logging.getLogger(__name__)

# Provider identifier written to User.oauth_provider for accounts that this
# module creates or looks up through the linux.do OAuth2 flow.
OAUTH_PROVIDER_LINUXDO = "linuxdo"
class OAuth2EndpointError(AuthenticationError):
    """
    OAuth2 endpoint error.

    In this module it is raised when no OAuth2 endpoint (primary or reserve)
    could serve a request. The error code is always "OAUTH2_ENDPOINT_ERROR".
    """

    def __init__(self, message: str = "OAuth2 服务不可用"):
        error_code = "OAUTH2_ENDPOINT_ERROR"
        super().__init__(message, error_code)
class OAuth2StateError(AuthenticationError):
    """
    OAuth2 state validation error.

    Raised when the ``state`` value returned on the OAuth2 callback fails the
    CSRF check. The error code is always "OAUTH2_STATE_ERROR".
    """

    def __init__(self, message: str = "无效的状态码"):
        error_code = "OAUTH2_STATE_ERROR"
        super().__init__(message, error_code)
class OAuth2Service:
    """
    OAuth2 authentication service.

    Features:
    - Each server-side endpoint (token, userinfo) has a primary and a reserve
      URL. When a request to the primary fails (timeout, connection error,
      other send error, or a 5xx response), it is retried against the reserve.
    - A one-time ``state`` value guards the callback against CSRF. Unused
      states expire after ``STATE_TTL_SECONDS``, so abandoned login attempts
      cannot grow the store without bound.

    NOTE: the state store is an in-process dict shared by all instances of
    this class. It is not shared across worker processes; a deployment with
    several workers should move it to a shared store (e.g. Redis).
    """

    # How long (seconds) an issued state remains valid if never consumed.
    STATE_TTL_SECONDS: float = 600.0

    # state -> time.monotonic() timestamp at which it was issued. Declared at
    # class level so every instance of the service sees the same store.
    _state_store: dict[str, float] = {}

    def __init__(self, session: AsyncSession):
        """
        Initialize the OAuth2 service.

        Args:
            session: Database session; also used to build the user repository.
        """
        self.session = session
        self.user_repo = UserRepository(session)

        # Primary / reserve URL for each endpoint type, read from settings.
        self._endpoints = {
            "authorize": {
                "primary": settings.oauth2_authorize_endpoint,
                "reserve": settings.oauth2_authorize_endpoint_reserve,
            },
            "token": {
                "primary": settings.oauth2_token_endpoint,
                "reserve": settings.oauth2_token_endpoint_reserve,
            },
            "userinfo": {
                "primary": settings.oauth2_user_info_endpoint,
                "reserve": settings.oauth2_user_info_endpoint_reserve,
            },
        }
        self._timeout = settings.oauth2_request_timeout

    # ------------------------------------------------------------------
    # State (CSRF) handling
    # ------------------------------------------------------------------

    def _prune_expired_states(self, now: float) -> None:
        """Remove issued states older than ``STATE_TTL_SECONDS``."""
        expired = [
            state
            for state, issued_at in self._state_store.items()
            if now - issued_at > self.STATE_TTL_SECONDS
        ]
        for state in expired:
            self._state_store.pop(state, None)

    def generate_authorize_url(self, redirect_uri: str) -> tuple[str, str]:
        """
        Build the OAuth2 authorization URL and register a fresh state value.

        The user's browser follows this URL, so the server cannot detect
        whether the primary authorize endpoint is reachable. The reserve
        authorize endpoint is used only when no primary one is configured.

        Args:
            redirect_uri: Callback URL; it must match the one later passed to
                ``exchange_code_for_token``.

        Returns:
            (authorization URL, state)

        Raises:
            OAuth2EndpointError: No authorize endpoint is configured.
        """
        endpoints = self._endpoints["authorize"]
        base_url = endpoints["primary"] or endpoints["reserve"]
        if not base_url:
            raise OAuth2EndpointError("OAuth2 authorize 端点未配置")
        base_url = str(base_url)

        now = time.monotonic()
        self._prune_expired_states(now)
        state = secrets.token_urlsafe(32)
        self._state_store[state] = now

        params = {
            "client_id": settings.oauth2_client_id,
            "redirect_uri": redirect_uri,
            "response_type": "code",
            "state": state,
            "scope": "read",  # adjust to the scopes the application needs
        }
        # Join with "&" if the configured endpoint already carries a query string.
        separator = "&" if "?" in base_url else "?"
        return f"{base_url}{separator}{urlencode(params)}", state

    def validate_state(self, state: str) -> bool:
        """
        Validate and consume a state value (CSRF protection).

        A state is accepted at most once. It is removed from the store on
        lookup, whether or not it has expired.

        Args:
            state: State value echoed back by the provider on the callback.

        Returns:
            True if the state was issued in this process, has not been used
            yet, and is no older than ``STATE_TTL_SECONDS``; otherwise False.
        """
        issued_at = self._state_store.pop(state, None)
        if issued_at is None:
            return False
        return time.monotonic() - issued_at <= self.STATE_TTL_SECONDS

    # ------------------------------------------------------------------
    # HTTP helpers
    # ------------------------------------------------------------------

    async def _request_with_fallback(
        self,
        endpoint_type: str,
        method: str,
        **kwargs,
    ) -> httpx.Response:
        """
        Send an HTTP request, falling back from the primary to the reserve endpoint.

        An endpoint counts as failed on timeout, connection error, any other
        exception raised while sending, or a 5xx response. Unconfigured
        (empty) endpoints are skipped. Any non-5xx response, including 4xx,
        is returned to the caller unchanged.

        Args:
            endpoint_type: Key into the endpoint table ("token" / "userinfo").
            method: HTTP method. "POST" (case-insensitive) sends a POST;
                anything else sends a GET.
            **kwargs: Passed through to ``httpx.AsyncClient.post`` / ``get``.

        Returns:
            The first non-5xx response.

        Raises:
            OAuth2EndpointError: Every configured endpoint failed. The cause
                of the last failure is included in the message and chained
                as ``__cause__`` when it was an exception.
        """
        endpoints = self._endpoints[endpoint_type]
        is_post = method.upper() == "POST"
        last_error: str | None = None
        last_exc: Exception | None = None

        for endpoint_name in ("primary", "reserve"):
            url = endpoints[endpoint_name]
            if not url:
                # Reserve endpoints may be left unset in configuration.
                logger.debug("OAuth2 %s 端点 (%s) 未配置,跳过", endpoint_type, endpoint_name)
                continue

            logger.debug("尝试 OAuth2 %s 端点 (%s): %s", endpoint_type, endpoint_name, url)
            try:
                async with httpx.AsyncClient(timeout=self._timeout) as client:
                    if is_post:
                        response = await client.post(url, **kwargs)
                    else:
                        response = await client.get(url, **kwargs)
            except httpx.TimeoutException as e:
                logger.warning("OAuth2 %s 端点 (%s) 超时: %s", endpoint_type, endpoint_name, e)
                # httpx timeouts often carry an empty message; fall back to the type name.
                last_error, last_exc = str(e) or type(e).__name__, e
                continue
            except httpx.ConnectError as e:
                logger.warning("OAuth2 %s 端点 (%s) 连接失败: %s", endpoint_type, endpoint_name, e)
                last_error, last_exc = str(e) or type(e).__name__, e
                continue
            except Exception as e:
                logger.error(
                    "OAuth2 %s 端点 (%s) 请求异常: %s",
                    endpoint_type,
                    endpoint_name,
                    e,
                    exc_info=True,
                )
                last_error, last_exc = str(e) or type(e).__name__, e
                continue

            if response.status_code >= 500:
                logger.warning(
                    "OAuth2 %s 端点 (%s) 返回服务器错误: %s",
                    endpoint_type,
                    endpoint_name,
                    response.status_code,
                )
                last_error, last_exc = f"HTTP {response.status_code}", None
                continue

            logger.info("OAuth2 %s 请求成功 (%s)", endpoint_type, endpoint_name)
            return response

        # Every configured endpoint failed (or none was configured).
        error_msg = f"OAuth2 {endpoint_type} 服务不可用"
        if last_error:
            error_msg += f": {last_error}"
        raise OAuth2EndpointError(error_msg) from last_exc

    @staticmethod
    def _read_json_object(
        response: httpx.Response,
        error_message: str,
        error_code: str,
    ) -> dict:
        """
        Return the response body as a JSON object.

        Raises:
            AuthenticationError: The body is not valid JSON or not a JSON
                object. Carries ``error_message`` and ``error_code``.
        """
        try:
            payload = response.json()
        except ValueError as e:  # json.JSONDecodeError / UnicodeDecodeError
            logger.error("OAuth2 响应不是有效的 JSON: %s", e)
            raise AuthenticationError(error_message, error_code) from e
        if not isinstance(payload, dict):
            logger.error("OAuth2 响应不是 JSON 对象: %s", type(payload).__name__)
            raise AuthenticationError(error_message, error_code)
        return payload

    # ------------------------------------------------------------------
    # Provider calls
    # ------------------------------------------------------------------

    async def exchange_code_for_token(
        self,
        code: str,
        redirect_uri: str,
    ) -> OAuth2TokenData:
        """
        Exchange an authorization code for an access token.

        Args:
            code: Authorization code from the callback.
            redirect_uri: Callback URL; must match the one used for authorization.

        Returns:
            OAuth2 token data.

        Raises:
            OAuth2EndpointError: No token endpoint is reachable.
            AuthenticationError: Non-200 response, or a body that cannot be
                parsed into ``OAuth2TokenData``.
        """
        data = {
            "grant_type": "authorization_code",
            "client_id": settings.oauth2_client_id,
            "client_secret": settings.oauth2_client_secret,
            "code": code,
            "redirect_uri": redirect_uri,
        }
        response = await self._request_with_fallback(
            "token",
            "POST",
            data=data,
            headers={"Accept": "application/json"},
        )
        if response.status_code != 200:
            logger.error("OAuth2 token 响应错误: %s - %s", response.status_code, response.text)
            raise AuthenticationError(
                f"获取访问令牌失败: {response.status_code}",
                "OAUTH2_TOKEN_ERROR",
            )

        token_data = self._read_json_object(
            response, "获取访问令牌失败: 响应格式无效", "OAUTH2_TOKEN_ERROR"
        )
        try:
            return OAuth2TokenData(**token_data)
        except (ValueError, TypeError) as e:
            # Schema mismatch, e.g. a 200 body carrying {"error": ...} instead of
            # a token (assumes the schema requires access_token — verify).
            # pydantic's ValidationError subclasses ValueError.
            logger.error("OAuth2 token 响应字段无效: %s", e)
            raise AuthenticationError(
                "获取访问令牌失败: 响应格式无效",
                "OAUTH2_TOKEN_ERROR",
            ) from e

    async def get_user_info(self, access_token: str) -> OAuth2UserInfo:
        """
        Fetch the OAuth2 user profile.

        Args:
            access_token: OAuth2 access token, sent as a Bearer token.

        Returns:
            User info.

        Raises:
            OAuth2EndpointError: No userinfo endpoint is reachable.
            AuthenticationError: Non-200 response, or a body that cannot be
                parsed into ``OAuth2UserInfo``.
        """
        response = await self._request_with_fallback(
            "userinfo",
            "GET",
            headers={
                "Authorization": f"Bearer {access_token}",
                "Accept": "application/json",
            },
        )
        if response.status_code != 200:
            logger.error("OAuth2 userinfo 响应错误: %s - %s", response.status_code, response.text)
            raise AuthenticationError(
                f"获取用户信息失败: {response.status_code}",
                "OAUTH2_USERINFO_ERROR",
            )

        user_data = self._read_json_object(
            response, "获取用户信息失败: 响应格式无效", "OAUTH2_USERINFO_ERROR"
        )
        try:
            return OAuth2UserInfo(**user_data)
        except (ValueError, TypeError) as e:
            logger.error("OAuth2 userinfo 响应字段无效: %s", e)
            raise AuthenticationError(
                "获取用户信息失败: 响应格式无效",
                "OAUTH2_USERINFO_ERROR",
            ) from e

    # ------------------------------------------------------------------
    # Full flow
    # ------------------------------------------------------------------

    async def authenticate(
        self,
        code: str,
        state: str,
        redirect_uri: str,
    ) -> tuple[User, TokenResponse, bool]:
        """
        Run the complete OAuth2 login flow.

        1. Validate and consume the state value.
        2. Exchange the authorization code for a provider token.
        3. Fetch the provider user profile.
        4. Find or create the local user.
        5. Refuse disabled accounts, then issue local JWTs.

        Args:
            code: Authorization code.
            state: State value from the callback.
            redirect_uri: Callback URL.

        Returns:
            (user, JWT token response, whether the user was newly created)

        Raises:
            OAuth2StateError: The state is unknown, already used, or expired.
            OAuth2EndpointError: The OAuth2 provider is unreachable.
            AuthenticationError: Any other authentication failure, including
                an inactive local account.
        """
        if not self.validate_state(state):
            raise OAuth2StateError()

        oauth_token = await self.exchange_code_for_token(code, redirect_uri)
        oauth_user = await self.get_user_info(oauth_token.access_token)
        user, is_new_user = await self._get_or_create_user(oauth_user)

        # Never hand out tokens for a disabled account. This covers both
        # locally disabled users and new users created from an inactive
        # provider account (is_active=oauth_user.active).
        if not user.is_active:
            raise AuthenticationError("用户账户已被禁用", "ACCOUNT_DISABLED")

        tokens = self._create_tokens(user)
        return user, tokens, is_new_user

    async def _pick_username(self, oauth_user: OAuth2UserInfo, oauth_user_id: str) -> str:
        """
        Choose a free local username for a new OAuth2 user.

        Tries the lower-cased provider username, then that name suffixed with
        the first 8 characters of the provider user id. If both are taken,
        appends a short random hex suffix instead of colliding.
        """
        base = oauth_user.username.lower()
        with_id = f"{base}_{oauth_user_id[:8]}"
        for candidate in (base, with_id):
            if not await self.user_repo.get_by_username(candidate):
                return candidate
        # Both taken (rare): a random suffix makes a further collision negligible.
        return f"{with_id}_{secrets.token_hex(2)}"

    async def _get_or_create_user(
        self,
        oauth_user: OAuth2UserInfo,
    ) -> tuple[User, bool]:
        """
        Find the local user linked to an OAuth2 identity, or create one.

        An existing user has its nickname and avatar refreshed from the
        provider profile. The session is committed in both paths.

        Args:
            oauth_user: OAuth2 user info.

        Returns:
            (user, whether it was newly created)
        """
        oauth_user_id = str(oauth_user.id)
        nickname = oauth_user.name or oauth_user.username

        user = await self.user_repo.get_by_oauth(
            provider=OAUTH_PROVIDER_LINUXDO,
            oauth_user_id=oauth_user_id,
        )
        if user:
            # Profile fields such as the avatar may change on the provider side.
            await self.user_repo.update(
                user,
                nickname=nickname,
                avatar_url=oauth_user.avatar_url,
            )
            await self.user_repo.commit()
            return user, False

        username = await self._pick_username(oauth_user, oauth_user_id)
        user = await self.user_repo.create(
            username=username,
            email=oauth_user.email,
            nickname=nickname,
            avatar_url=oauth_user.avatar_url,
            oauth_provider=OAUTH_PROVIDER_LINUXDO,
            oauth_user_id=oauth_user_id,
            hashed_password=None,  # OAuth2 users have no local password
            is_active=oauth_user.active,
        )
        await self.user_repo.commit()
        logger.info("创建 OAuth2 用户: %s (provider=%s)", user.username, OAUTH_PROVIDER_LINUXDO)
        return user, True

    def _create_tokens(self, user: User) -> TokenResponse:
        """
        Issue a local JWT access/refresh token pair for a user.

        Args:
            user: User object.

        Returns:
            Token response. ``expires_in`` is in seconds, derived from
            ``settings.access_token_expire_minutes``.
        """
        access_token = create_access_token(
            subject=user.id,
            extra_claims={
                "username": user.username,
                "is_superuser": user.is_superuser,
                "oauth_provider": user.oauth_provider,
            },
        )
        refresh_token = create_refresh_token(subject=user.id)
        return TokenResponse(
            access_token=access_token,
            refresh_token=refresh_token,
            token_type="Bearer",
            expires_in=settings.access_token_expire_minutes * 60,
        )