提供基本前后端骨架
This commit is contained in:
395
app/services/oauth2.py
Normal file
395
app/services/oauth2.py
Normal file
@@ -0,0 +1,395 @@
|
||||
"""
|
||||
OAuth2 认证服务
|
||||
|
||||
处理 OAuth2 认证流程,支持主备端点自动切换。
|
||||
当首选端点不可达时,自动回退到备用端点。
|
||||
"""
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import httpx
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.exceptions import (
|
||||
AuthenticationError,
|
||||
ResourceConflictError,
|
||||
)
|
||||
from app.core.security import create_access_token, create_refresh_token
|
||||
from app.models.user import User
|
||||
from app.repositories.user import UserRepository
|
||||
from app.schemas.auth import TokenResponse
|
||||
from app.schemas.oauth2 import OAuth2TokenData, OAuth2UserInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# OAuth2 提供商标识
|
||||
OAUTH_PROVIDER_LINUXDO = "linuxdo"
|
||||
|
||||
|
||||
class OAuth2EndpointError(AuthenticationError):
|
||||
"""OAuth2 端点错误"""
|
||||
|
||||
def __init__(self, message: str = "OAuth2 服务不可用"):
|
||||
super().__init__(message, "OAUTH2_ENDPOINT_ERROR")
|
||||
|
||||
|
||||
class OAuth2StateError(AuthenticationError):
|
||||
"""OAuth2 状态验证错误"""
|
||||
|
||||
def __init__(self, message: str = "无效的状态码"):
|
||||
super().__init__(message, "OAUTH2_STATE_ERROR")
|
||||
|
||||
|
||||
class OAuth2Service:
|
||||
"""
|
||||
OAuth2 认证服务
|
||||
|
||||
特性:
|
||||
- 支持主备端点自动切换
|
||||
- 首选端点请求失败时自动回退到备用端点
|
||||
- 状态码验证防止 CSRF 攻击
|
||||
"""
|
||||
|
||||
# 存储状态码(生产环境应使用 Redis)
|
||||
_state_store: dict[str, bool] = {}
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
"""
|
||||
初始化 OAuth2 服务
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
"""
|
||||
self.session = session
|
||||
self.user_repo = UserRepository(session)
|
||||
|
||||
# 端点配置
|
||||
self._endpoints = {
|
||||
"authorize": {
|
||||
"primary": settings.oauth2_authorize_endpoint,
|
||||
"reserve": settings.oauth2_authorize_endpoint_reserve,
|
||||
},
|
||||
"token": {
|
||||
"primary": settings.oauth2_token_endpoint,
|
||||
"reserve": settings.oauth2_token_endpoint_reserve,
|
||||
},
|
||||
"userinfo": {
|
||||
"primary": settings.oauth2_user_info_endpoint,
|
||||
"reserve": settings.oauth2_user_info_endpoint_reserve,
|
||||
},
|
||||
}
|
||||
|
||||
self._timeout = settings.oauth2_request_timeout
|
||||
|
||||
def generate_authorize_url(self, redirect_uri: str) -> tuple[str, str]:
|
||||
"""
|
||||
生成 OAuth2 授权 URL
|
||||
|
||||
Args:
|
||||
redirect_uri: 回调 URL
|
||||
|
||||
Returns:
|
||||
(授权 URL, 状态码)
|
||||
"""
|
||||
state = secrets.token_urlsafe(32)
|
||||
self._state_store[state] = True # 存储状态码
|
||||
|
||||
params = {
|
||||
"client_id": settings.oauth2_client_id,
|
||||
"redirect_uri": redirect_uri,
|
||||
"response_type": "code",
|
||||
"state": state,
|
||||
"scope": "read", # 根据实际需要调整
|
||||
}
|
||||
|
||||
# 使用首选授权端点
|
||||
authorize_url = f"{self._endpoints['authorize']['primary']}?{urlencode(params)}"
|
||||
|
||||
return authorize_url, state
|
||||
|
||||
def validate_state(self, state: str) -> bool:
|
||||
"""
|
||||
验证状态码(防 CSRF)
|
||||
|
||||
Args:
|
||||
state: 状态码
|
||||
|
||||
Returns:
|
||||
是否有效
|
||||
"""
|
||||
if state in self._state_store:
|
||||
del self._state_store[state] # 使用后立即删除
|
||||
return True
|
||||
return False
|
||||
|
||||
async def _request_with_fallback(
|
||||
self,
|
||||
endpoint_type: str,
|
||||
method: str,
|
||||
**kwargs,
|
||||
) -> httpx.Response:
|
||||
"""
|
||||
带回退的 HTTP 请求
|
||||
|
||||
首先尝试首选端点,失败后自动切换到备用端点。
|
||||
|
||||
Args:
|
||||
endpoint_type: 端点类型(token/userinfo)
|
||||
method: HTTP 方法
|
||||
**kwargs: 请求参数
|
||||
|
||||
Returns:
|
||||
响应对象
|
||||
|
||||
Raises:
|
||||
OAuth2EndpointError: 所有端点都不可用
|
||||
"""
|
||||
endpoints = self._endpoints[endpoint_type]
|
||||
last_error: Exception | None = None
|
||||
|
||||
for endpoint_name, url in [("primary", endpoints["primary"]), ("reserve", endpoints["reserve"])]:
|
||||
try:
|
||||
logger.debug(f"尝试 OAuth2 {endpoint_type} 端点 ({endpoint_name}): {url}")
|
||||
|
||||
async with httpx.AsyncClient(timeout=self._timeout) as client:
|
||||
if method.upper() == "POST":
|
||||
response = await client.post(url, **kwargs)
|
||||
else:
|
||||
response = await client.get(url, **kwargs)
|
||||
|
||||
# 检查 HTTP 状态
|
||||
if response.status_code >= 500:
|
||||
logger.warning(
|
||||
f"OAuth2 {endpoint_type} 端点 ({endpoint_name}) "
|
||||
f"返回服务器错误: {response.status_code}"
|
||||
)
|
||||
continue
|
||||
|
||||
logger.info(f"OAuth2 {endpoint_type} 请求成功 ({endpoint_name})")
|
||||
return response
|
||||
|
||||
except httpx.TimeoutException as e:
|
||||
logger.warning(f"OAuth2 {endpoint_type} 端点 ({endpoint_name}) 超时: {e}")
|
||||
last_error = e
|
||||
except httpx.ConnectError as e:
|
||||
logger.warning(f"OAuth2 {endpoint_type} 端点 ({endpoint_name}) 连接失败: {e}")
|
||||
last_error = e
|
||||
except Exception as e:
|
||||
logger.error(f"OAuth2 {endpoint_type} 端点 ({endpoint_name}) 请求异常: {e}")
|
||||
last_error = e
|
||||
|
||||
# 所有端点都失败
|
||||
error_msg = f"OAuth2 {endpoint_type} 服务不可用"
|
||||
if last_error:
|
||||
error_msg += f": {last_error}"
|
||||
raise OAuth2EndpointError(error_msg)
|
||||
|
||||
async def exchange_code_for_token(
|
||||
self,
|
||||
code: str,
|
||||
redirect_uri: str,
|
||||
) -> OAuth2TokenData:
|
||||
"""
|
||||
用授权码换取访问令牌
|
||||
|
||||
Args:
|
||||
code: 授权码
|
||||
redirect_uri: 回调 URL(必须与授权时一致)
|
||||
|
||||
Returns:
|
||||
OAuth2 令牌数据
|
||||
|
||||
Raises:
|
||||
OAuth2EndpointError: 端点不可用
|
||||
AuthenticationError: 换取令牌失败
|
||||
"""
|
||||
data = {
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": settings.oauth2_client_id,
|
||||
"client_secret": settings.oauth2_client_secret,
|
||||
"code": code,
|
||||
"redirect_uri": redirect_uri,
|
||||
}
|
||||
|
||||
response = await self._request_with_fallback(
|
||||
"token",
|
||||
"POST",
|
||||
data=data,
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"OAuth2 token 响应错误: {response.status_code} - {response.text}")
|
||||
raise AuthenticationError(
|
||||
f"获取访问令牌失败: {response.status_code}",
|
||||
"OAUTH2_TOKEN_ERROR",
|
||||
)
|
||||
|
||||
token_data = response.json()
|
||||
return OAuth2TokenData(**token_data)
|
||||
|
||||
async def get_user_info(self, access_token: str) -> OAuth2UserInfo:
|
||||
"""
|
||||
获取 OAuth2 用户信息
|
||||
|
||||
Args:
|
||||
access_token: OAuth2 访问令牌
|
||||
|
||||
Returns:
|
||||
用户信息
|
||||
|
||||
Raises:
|
||||
OAuth2EndpointError: 端点不可用
|
||||
AuthenticationError: 获取用户信息失败
|
||||
"""
|
||||
response = await self._request_with_fallback(
|
||||
"userinfo",
|
||||
"GET",
|
||||
headers={
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Accept": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"OAuth2 userinfo 响应错误: {response.status_code} - {response.text}")
|
||||
raise AuthenticationError(
|
||||
f"获取用户信息失败: {response.status_code}",
|
||||
"OAUTH2_USERINFO_ERROR",
|
||||
)
|
||||
|
||||
user_data = response.json()
|
||||
return OAuth2UserInfo(**user_data)
|
||||
|
||||
async def authenticate(
|
||||
self,
|
||||
code: str,
|
||||
state: str,
|
||||
redirect_uri: str,
|
||||
) -> tuple[User, TokenResponse, bool]:
|
||||
"""
|
||||
完整的 OAuth2 认证流程
|
||||
|
||||
1. 验证状态码
|
||||
2. 用授权码换取令牌
|
||||
3. 获取用户信息
|
||||
4. 创建或更新用户
|
||||
5. 生成 JWT 令牌
|
||||
|
||||
Args:
|
||||
code: 授权码
|
||||
state: 状态码
|
||||
redirect_uri: 回调 URL
|
||||
|
||||
Returns:
|
||||
(用户对象, JWT 令牌响应, 是否新用户)
|
||||
|
||||
Raises:
|
||||
OAuth2StateError: 状态码无效
|
||||
OAuth2EndpointError: OAuth2 服务不可用
|
||||
AuthenticationError: 认证失败
|
||||
"""
|
||||
# 1. 验证状态码
|
||||
if not self.validate_state(state):
|
||||
raise OAuth2StateError()
|
||||
|
||||
# 2. 换取令牌
|
||||
oauth_token = await self.exchange_code_for_token(code, redirect_uri)
|
||||
|
||||
# 3. 获取用户信息
|
||||
oauth_user = await self.get_user_info(oauth_token.access_token)
|
||||
|
||||
# 4. 查找或创建用户
|
||||
user, is_new_user = await self._get_or_create_user(oauth_user)
|
||||
|
||||
# 5. 生成 JWT 令牌
|
||||
tokens = self._create_tokens(user)
|
||||
|
||||
return user, tokens, is_new_user
|
||||
|
||||
async def _get_or_create_user(
|
||||
self,
|
||||
oauth_user: OAuth2UserInfo,
|
||||
) -> tuple[User, bool]:
|
||||
"""
|
||||
根据 OAuth2 用户信息获取或创建本地用户
|
||||
|
||||
Args:
|
||||
oauth_user: OAuth2 用户信息
|
||||
|
||||
Returns:
|
||||
(用户对象, 是否新创建)
|
||||
"""
|
||||
oauth_user_id = str(oauth_user.id)
|
||||
|
||||
# 先通过 OAuth ID 查找
|
||||
user = await self.user_repo.get_by_oauth(
|
||||
provider=OAUTH_PROVIDER_LINUXDO,
|
||||
oauth_user_id=oauth_user_id,
|
||||
)
|
||||
|
||||
if user:
|
||||
# 更新用户信息(头像等可能变化)
|
||||
await self.user_repo.update(
|
||||
user,
|
||||
nickname=oauth_user.name or oauth_user.username,
|
||||
avatar_url=oauth_user.avatar_url,
|
||||
)
|
||||
await self.user_repo.commit()
|
||||
return user, False
|
||||
|
||||
# 检查用户名是否已存在
|
||||
username = oauth_user.username.lower()
|
||||
existing_user = await self.user_repo.get_by_username(username)
|
||||
if existing_user:
|
||||
# 用户名冲突,添加后缀
|
||||
username = f"{username}_{oauth_user_id[:8]}"
|
||||
|
||||
# 创建新用户
|
||||
user = await self.user_repo.create(
|
||||
username=username,
|
||||
email=oauth_user.email,
|
||||
nickname=oauth_user.name or oauth_user.username,
|
||||
avatar_url=oauth_user.avatar_url,
|
||||
oauth_provider=OAUTH_PROVIDER_LINUXDO,
|
||||
oauth_user_id=oauth_user_id,
|
||||
hashed_password=None, # OAuth2 用户无密码
|
||||
is_active=oauth_user.active,
|
||||
)
|
||||
|
||||
await self.user_repo.commit()
|
||||
logger.info(f"创建 OAuth2 用户: {user.username} (provider={OAUTH_PROVIDER_LINUXDO})")
|
||||
|
||||
return user, True
|
||||
|
||||
def _create_tokens(self, user: User) -> TokenResponse:
|
||||
"""
|
||||
为用户创建 JWT 令牌
|
||||
|
||||
Args:
|
||||
user: 用户对象
|
||||
|
||||
Returns:
|
||||
令牌响应
|
||||
"""
|
||||
access_token = create_access_token(
|
||||
subject=user.id,
|
||||
extra_claims={
|
||||
"username": user.username,
|
||||
"is_superuser": user.is_superuser,
|
||||
"oauth_provider": user.oauth_provider,
|
||||
},
|
||||
)
|
||||
|
||||
refresh_token = create_refresh_token(subject=user.id)
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
token_type="Bearer",
|
||||
expires_in=settings.access_token_expire_minutes * 60,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user