396 lines
12 KiB
Python
396 lines
12 KiB
Python
"""
|
||
OAuth2 认证服务
|
||
|
||
处理 OAuth2 认证流程,支持主备端点自动切换。
|
||
当首选端点不可达时,自动回退到备用端点。
|
||
"""
|
||
|
||
import logging
|
||
import secrets
|
||
from urllib.parse import urlencode
|
||
|
||
import httpx
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.core.config import settings
|
||
from app.core.exceptions import (
|
||
AuthenticationError,
|
||
ResourceConflictError,
|
||
)
|
||
from app.core.security import create_access_token, create_refresh_token
|
||
from app.models.user import User
|
||
from app.repositories.user import UserRepository
|
||
from app.schemas.auth import TokenResponse
|
||
from app.schemas.oauth2 import OAuth2TokenData, OAuth2UserInfo
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# OAuth2 提供商标识
|
||
OAUTH_PROVIDER_LINUXDO = "linuxdo"
|
||
|
||
|
||
class OAuth2EndpointError(AuthenticationError):
|
||
"""OAuth2 端点错误"""
|
||
|
||
def __init__(self, message: str = "OAuth2 服务不可用"):
|
||
super().__init__(message, "OAUTH2_ENDPOINT_ERROR")
|
||
|
||
|
||
class OAuth2StateError(AuthenticationError):
|
||
"""OAuth2 状态验证错误"""
|
||
|
||
def __init__(self, message: str = "无效的状态码"):
|
||
super().__init__(message, "OAUTH2_STATE_ERROR")
|
||
|
||
|
||
class OAuth2Service:
|
||
"""
|
||
OAuth2 认证服务
|
||
|
||
特性:
|
||
- 支持主备端点自动切换
|
||
- 首选端点请求失败时自动回退到备用端点
|
||
- 状态码验证防止 CSRF 攻击
|
||
"""
|
||
|
||
# 存储状态码(生产环境应使用 Redis)
|
||
_state_store: dict[str, bool] = {}
|
||
|
||
def __init__(self, session: AsyncSession):
|
||
"""
|
||
初始化 OAuth2 服务
|
||
|
||
Args:
|
||
session: 数据库会话
|
||
"""
|
||
self.session = session
|
||
self.user_repo = UserRepository(session)
|
||
|
||
# 端点配置
|
||
self._endpoints = {
|
||
"authorize": {
|
||
"primary": settings.oauth2_authorize_endpoint,
|
||
"reserve": settings.oauth2_authorize_endpoint_reserve,
|
||
},
|
||
"token": {
|
||
"primary": settings.oauth2_token_endpoint,
|
||
"reserve": settings.oauth2_token_endpoint_reserve,
|
||
},
|
||
"userinfo": {
|
||
"primary": settings.oauth2_user_info_endpoint,
|
||
"reserve": settings.oauth2_user_info_endpoint_reserve,
|
||
},
|
||
}
|
||
|
||
self._timeout = settings.oauth2_request_timeout
|
||
|
||
def generate_authorize_url(self, redirect_uri: str) -> tuple[str, str]:
|
||
"""
|
||
生成 OAuth2 授权 URL
|
||
|
||
Args:
|
||
redirect_uri: 回调 URL
|
||
|
||
Returns:
|
||
(授权 URL, 状态码)
|
||
"""
|
||
state = secrets.token_urlsafe(32)
|
||
self._state_store[state] = True # 存储状态码
|
||
|
||
params = {
|
||
"client_id": settings.oauth2_client_id,
|
||
"redirect_uri": redirect_uri,
|
||
"response_type": "code",
|
||
"state": state,
|
||
"scope": "read", # 根据实际需要调整
|
||
}
|
||
|
||
# 使用首选授权端点
|
||
authorize_url = f"{self._endpoints['authorize']['primary']}?{urlencode(params)}"
|
||
|
||
return authorize_url, state
|
||
|
||
def validate_state(self, state: str) -> bool:
|
||
"""
|
||
验证状态码(防 CSRF)
|
||
|
||
Args:
|
||
state: 状态码
|
||
|
||
Returns:
|
||
是否有效
|
||
"""
|
||
if state in self._state_store:
|
||
del self._state_store[state] # 使用后立即删除
|
||
return True
|
||
return False
|
||
|
||
async def _request_with_fallback(
|
||
self,
|
||
endpoint_type: str,
|
||
method: str,
|
||
**kwargs,
|
||
) -> httpx.Response:
|
||
"""
|
||
带回退的 HTTP 请求
|
||
|
||
首先尝试首选端点,失败后自动切换到备用端点。
|
||
|
||
Args:
|
||
endpoint_type: 端点类型(token/userinfo)
|
||
method: HTTP 方法
|
||
**kwargs: 请求参数
|
||
|
||
Returns:
|
||
响应对象
|
||
|
||
Raises:
|
||
OAuth2EndpointError: 所有端点都不可用
|
||
"""
|
||
endpoints = self._endpoints[endpoint_type]
|
||
last_error: Exception | None = None
|
||
|
||
for endpoint_name, url in [("primary", endpoints["primary"]), ("reserve", endpoints["reserve"])]:
|
||
try:
|
||
logger.debug(f"尝试 OAuth2 {endpoint_type} 端点 ({endpoint_name}): {url}")
|
||
|
||
async with httpx.AsyncClient(timeout=self._timeout) as client:
|
||
if method.upper() == "POST":
|
||
response = await client.post(url, **kwargs)
|
||
else:
|
||
response = await client.get(url, **kwargs)
|
||
|
||
# 检查 HTTP 状态
|
||
if response.status_code >= 500:
|
||
logger.warning(
|
||
f"OAuth2 {endpoint_type} 端点 ({endpoint_name}) "
|
||
f"返回服务器错误: {response.status_code}"
|
||
)
|
||
continue
|
||
|
||
logger.info(f"OAuth2 {endpoint_type} 请求成功 ({endpoint_name})")
|
||
return response
|
||
|
||
except httpx.TimeoutException as e:
|
||
logger.warning(f"OAuth2 {endpoint_type} 端点 ({endpoint_name}) 超时: {e}")
|
||
last_error = e
|
||
except httpx.ConnectError as e:
|
||
logger.warning(f"OAuth2 {endpoint_type} 端点 ({endpoint_name}) 连接失败: {e}")
|
||
last_error = e
|
||
except Exception as e:
|
||
logger.error(f"OAuth2 {endpoint_type} 端点 ({endpoint_name}) 请求异常: {e}")
|
||
last_error = e
|
||
|
||
# 所有端点都失败
|
||
error_msg = f"OAuth2 {endpoint_type} 服务不可用"
|
||
if last_error:
|
||
error_msg += f": {last_error}"
|
||
raise OAuth2EndpointError(error_msg)
|
||
|
||
async def exchange_code_for_token(
|
||
self,
|
||
code: str,
|
||
redirect_uri: str,
|
||
) -> OAuth2TokenData:
|
||
"""
|
||
用授权码换取访问令牌
|
||
|
||
Args:
|
||
code: 授权码
|
||
redirect_uri: 回调 URL(必须与授权时一致)
|
||
|
||
Returns:
|
||
OAuth2 令牌数据
|
||
|
||
Raises:
|
||
OAuth2EndpointError: 端点不可用
|
||
AuthenticationError: 换取令牌失败
|
||
"""
|
||
data = {
|
||
"grant_type": "authorization_code",
|
||
"client_id": settings.oauth2_client_id,
|
||
"client_secret": settings.oauth2_client_secret,
|
||
"code": code,
|
||
"redirect_uri": redirect_uri,
|
||
}
|
||
|
||
response = await self._request_with_fallback(
|
||
"token",
|
||
"POST",
|
||
data=data,
|
||
headers={"Accept": "application/json"},
|
||
)
|
||
|
||
if response.status_code != 200:
|
||
logger.error(f"OAuth2 token 响应错误: {response.status_code} - {response.text}")
|
||
raise AuthenticationError(
|
||
f"获取访问令牌失败: {response.status_code}",
|
||
"OAUTH2_TOKEN_ERROR",
|
||
)
|
||
|
||
token_data = response.json()
|
||
return OAuth2TokenData(**token_data)
|
||
|
||
async def get_user_info(self, access_token: str) -> OAuth2UserInfo:
|
||
"""
|
||
获取 OAuth2 用户信息
|
||
|
||
Args:
|
||
access_token: OAuth2 访问令牌
|
||
|
||
Returns:
|
||
用户信息
|
||
|
||
Raises:
|
||
OAuth2EndpointError: 端点不可用
|
||
AuthenticationError: 获取用户信息失败
|
||
"""
|
||
response = await self._request_with_fallback(
|
||
"userinfo",
|
||
"GET",
|
||
headers={
|
||
"Authorization": f"Bearer {access_token}",
|
||
"Accept": "application/json",
|
||
},
|
||
)
|
||
|
||
if response.status_code != 200:
|
||
logger.error(f"OAuth2 userinfo 响应错误: {response.status_code} - {response.text}")
|
||
raise AuthenticationError(
|
||
f"获取用户信息失败: {response.status_code}",
|
||
"OAUTH2_USERINFO_ERROR",
|
||
)
|
||
|
||
user_data = response.json()
|
||
return OAuth2UserInfo(**user_data)
|
||
|
||
async def authenticate(
|
||
self,
|
||
code: str,
|
||
state: str,
|
||
redirect_uri: str,
|
||
) -> tuple[User, TokenResponse, bool]:
|
||
"""
|
||
完整的 OAuth2 认证流程
|
||
|
||
1. 验证状态码
|
||
2. 用授权码换取令牌
|
||
3. 获取用户信息
|
||
4. 创建或更新用户
|
||
5. 生成 JWT 令牌
|
||
|
||
Args:
|
||
code: 授权码
|
||
state: 状态码
|
||
redirect_uri: 回调 URL
|
||
|
||
Returns:
|
||
(用户对象, JWT 令牌响应, 是否新用户)
|
||
|
||
Raises:
|
||
OAuth2StateError: 状态码无效
|
||
OAuth2EndpointError: OAuth2 服务不可用
|
||
AuthenticationError: 认证失败
|
||
"""
|
||
# 1. 验证状态码
|
||
if not self.validate_state(state):
|
||
raise OAuth2StateError()
|
||
|
||
# 2. 换取令牌
|
||
oauth_token = await self.exchange_code_for_token(code, redirect_uri)
|
||
|
||
# 3. 获取用户信息
|
||
oauth_user = await self.get_user_info(oauth_token.access_token)
|
||
|
||
# 4. 查找或创建用户
|
||
user, is_new_user = await self._get_or_create_user(oauth_user)
|
||
|
||
# 5. 生成 JWT 令牌
|
||
tokens = self._create_tokens(user)
|
||
|
||
return user, tokens, is_new_user
|
||
|
||
async def _get_or_create_user(
|
||
self,
|
||
oauth_user: OAuth2UserInfo,
|
||
) -> tuple[User, bool]:
|
||
"""
|
||
根据 OAuth2 用户信息获取或创建本地用户
|
||
|
||
Args:
|
||
oauth_user: OAuth2 用户信息
|
||
|
||
Returns:
|
||
(用户对象, 是否新创建)
|
||
"""
|
||
oauth_user_id = str(oauth_user.id)
|
||
|
||
# 先通过 OAuth ID 查找
|
||
user = await self.user_repo.get_by_oauth(
|
||
provider=OAUTH_PROVIDER_LINUXDO,
|
||
oauth_user_id=oauth_user_id,
|
||
)
|
||
|
||
if user:
|
||
# 更新用户信息(头像等可能变化)
|
||
await self.user_repo.update(
|
||
user,
|
||
nickname=oauth_user.name or oauth_user.username,
|
||
avatar_url=oauth_user.avatar_url,
|
||
)
|
||
await self.user_repo.commit()
|
||
return user, False
|
||
|
||
# 检查用户名是否已存在
|
||
username = oauth_user.username.lower()
|
||
existing_user = await self.user_repo.get_by_username(username)
|
||
if existing_user:
|
||
# 用户名冲突,添加后缀
|
||
username = f"{username}_{oauth_user_id[:8]}"
|
||
|
||
# 创建新用户
|
||
user = await self.user_repo.create(
|
||
username=username,
|
||
email=oauth_user.email,
|
||
nickname=oauth_user.name or oauth_user.username,
|
||
avatar_url=oauth_user.avatar_url,
|
||
oauth_provider=OAUTH_PROVIDER_LINUXDO,
|
||
oauth_user_id=oauth_user_id,
|
||
hashed_password=None, # OAuth2 用户无密码
|
||
is_active=oauth_user.active,
|
||
)
|
||
|
||
await self.user_repo.commit()
|
||
logger.info(f"创建 OAuth2 用户: {user.username} (provider={OAUTH_PROVIDER_LINUXDO})")
|
||
|
||
return user, True
|
||
|
||
def _create_tokens(self, user: User) -> TokenResponse:
|
||
"""
|
||
为用户创建 JWT 令牌
|
||
|
||
Args:
|
||
user: 用户对象
|
||
|
||
Returns:
|
||
令牌响应
|
||
"""
|
||
access_token = create_access_token(
|
||
subject=user.id,
|
||
extra_claims={
|
||
"username": user.username,
|
||
"is_superuser": user.is_superuser,
|
||
"oauth_provider": user.oauth_provider,
|
||
},
|
||
)
|
||
|
||
refresh_token = create_refresh_token(subject=user.id)
|
||
|
||
return TokenResponse(
|
||
access_token=access_token,
|
||
refresh_token=refresh_token,
|
||
token_type="Bearer",
|
||
expires_in=settings.access_token_expire_minutes * 60,
|
||
)
|
||
|