""" OAuth2 认证服务 处理 OAuth2 认证流程,支持主备端点自动切换。 当首选端点不可达时,自动回退到备用端点。 """ import logging import secrets from urllib.parse import urlencode import httpx from sqlalchemy.ext.asyncio import AsyncSession from app.core.config import settings from app.core.exceptions import ( AuthenticationError, ResourceConflictError, ) from app.core.security import create_access_token, create_refresh_token from app.models.user import User from app.repositories.user import UserRepository from app.schemas.auth import TokenResponse from app.schemas.oauth2 import OAuth2TokenData, OAuth2UserInfo logger = logging.getLogger(__name__) # OAuth2 提供商标识 OAUTH_PROVIDER_LINUXDO = "linuxdo" class OAuth2EndpointError(AuthenticationError): """OAuth2 端点错误""" def __init__(self, message: str = "OAuth2 服务不可用"): super().__init__(message, "OAUTH2_ENDPOINT_ERROR") class OAuth2StateError(AuthenticationError): """OAuth2 状态验证错误""" def __init__(self, message: str = "无效的状态码"): super().__init__(message, "OAUTH2_STATE_ERROR") class OAuth2Service: """ OAuth2 认证服务 特性: - 支持主备端点自动切换 - 首选端点请求失败时自动回退到备用端点 - 状态码验证防止 CSRF 攻击 """ # 存储状态码(生产环境应使用 Redis) _state_store: dict[str, bool] = {} def __init__(self, session: AsyncSession): """ 初始化 OAuth2 服务 Args: session: 数据库会话 """ self.session = session self.user_repo = UserRepository(session) # 端点配置 self._endpoints = { "authorize": { "primary": settings.oauth2_authorize_endpoint, "reserve": settings.oauth2_authorize_endpoint_reserve, }, "token": { "primary": settings.oauth2_token_endpoint, "reserve": settings.oauth2_token_endpoint_reserve, }, "userinfo": { "primary": settings.oauth2_user_info_endpoint, "reserve": settings.oauth2_user_info_endpoint_reserve, }, } self._timeout = settings.oauth2_request_timeout def generate_authorize_url(self, redirect_uri: str) -> tuple[str, str]: """ 生成 OAuth2 授权 URL Args: redirect_uri: 回调 URL Returns: (授权 URL, 状态码) """ state = secrets.token_urlsafe(32) self._state_store[state] = True # 存储状态码 params = { "client_id": settings.oauth2_client_id, "redirect_uri": redirect_uri, "response_type": "code", "state": state, "scope": "read", # 根据实际需要调整 } # 使用首选授权端点 authorize_url = f"{self._endpoints['authorize']['primary']}?{urlencode(params)}" return authorize_url, state def validate_state(self, state: str) -> bool: """ 验证状态码(防 CSRF) Args: state: 状态码 Returns: 是否有效 """ if state in self._state_store: del self._state_store[state] # 使用后立即删除 return True return False async def _request_with_fallback( self, endpoint_type: str, method: str, **kwargs, ) -> httpx.Response: """ 带回退的 HTTP 请求 首先尝试首选端点,失败后自动切换到备用端点。 Args: endpoint_type: 端点类型(token/userinfo) method: HTTP 方法 **kwargs: 请求参数 Returns: 响应对象 Raises: OAuth2EndpointError: 所有端点都不可用 """ endpoints = self._endpoints[endpoint_type] last_error: Exception | None = None for endpoint_name, url in [("primary", endpoints["primary"]), ("reserve", endpoints["reserve"])]: try: logger.debug(f"尝试 OAuth2 {endpoint_type} 端点 ({endpoint_name}): {url}") async with httpx.AsyncClient(timeout=self._timeout) as client: if method.upper() == "POST": response = await client.post(url, **kwargs) else: response = await client.get(url, **kwargs) # 检查 HTTP 状态 if response.status_code >= 500: logger.warning( f"OAuth2 {endpoint_type} 端点 ({endpoint_name}) " f"返回服务器错误: {response.status_code}" ) continue logger.info(f"OAuth2 {endpoint_type} 请求成功 ({endpoint_name})") return response except httpx.TimeoutException as e: logger.warning(f"OAuth2 {endpoint_type} 端点 ({endpoint_name}) 超时: {e}") last_error = e except httpx.ConnectError as e: logger.warning(f"OAuth2 {endpoint_type} 端点 ({endpoint_name}) 连接失败: {e}") last_error = e except Exception as e: logger.error(f"OAuth2 {endpoint_type} 端点 ({endpoint_name}) 请求异常: {e}") last_error = e # 所有端点都失败 error_msg = f"OAuth2 {endpoint_type} 服务不可用" if last_error: error_msg += f": {last_error}" raise OAuth2EndpointError(error_msg) async def exchange_code_for_token( self, code: str, redirect_uri: str, ) -> OAuth2TokenData: """ 用授权码换取访问令牌 Args: code: 授权码 redirect_uri: 回调 URL(必须与授权时一致) Returns: OAuth2 令牌数据 Raises: OAuth2EndpointError: 端点不可用 AuthenticationError: 换取令牌失败 """ data = { "grant_type": "authorization_code", "client_id": settings.oauth2_client_id, "client_secret": settings.oauth2_client_secret, "code": code, "redirect_uri": redirect_uri, } response = await self._request_with_fallback( "token", "POST", data=data, headers={"Accept": "application/json"}, ) if response.status_code != 200: logger.error(f"OAuth2 token 响应错误: {response.status_code} - {response.text}") raise AuthenticationError( f"获取访问令牌失败: {response.status_code}", "OAUTH2_TOKEN_ERROR", ) token_data = response.json() return OAuth2TokenData(**token_data) async def get_user_info(self, access_token: str) -> OAuth2UserInfo: """ 获取 OAuth2 用户信息 Args: access_token: OAuth2 访问令牌 Returns: 用户信息 Raises: OAuth2EndpointError: 端点不可用 AuthenticationError: 获取用户信息失败 """ response = await self._request_with_fallback( "userinfo", "GET", headers={ "Authorization": f"Bearer {access_token}", "Accept": "application/json", }, ) if response.status_code != 200: logger.error(f"OAuth2 userinfo 响应错误: {response.status_code} - {response.text}") raise AuthenticationError( f"获取用户信息失败: {response.status_code}", "OAUTH2_USERINFO_ERROR", ) user_data = response.json() return OAuth2UserInfo(**user_data) async def authenticate( self, code: str, state: str, redirect_uri: str, ) -> tuple[User, TokenResponse, bool]: """ 完整的 OAuth2 认证流程 1. 验证状态码 2. 用授权码换取令牌 3. 获取用户信息 4. 创建或更新用户 5. 生成 JWT 令牌 Args: code: 授权码 state: 状态码 redirect_uri: 回调 URL Returns: (用户对象, JWT 令牌响应, 是否新用户) Raises: OAuth2StateError: 状态码无效 OAuth2EndpointError: OAuth2 服务不可用 AuthenticationError: 认证失败 """ # 1. 验证状态码 if not self.validate_state(state): raise OAuth2StateError() # 2. 换取令牌 oauth_token = await self.exchange_code_for_token(code, redirect_uri) # 3. 获取用户信息 oauth_user = await self.get_user_info(oauth_token.access_token) # 4. 查找或创建用户 user, is_new_user = await self._get_or_create_user(oauth_user) # 5. 生成 JWT 令牌 tokens = self._create_tokens(user) return user, tokens, is_new_user async def _get_or_create_user( self, oauth_user: OAuth2UserInfo, ) -> tuple[User, bool]: """ 根据 OAuth2 用户信息获取或创建本地用户 Args: oauth_user: OAuth2 用户信息 Returns: (用户对象, 是否新创建) """ oauth_user_id = str(oauth_user.id) # 先通过 OAuth ID 查找 user = await self.user_repo.get_by_oauth( provider=OAUTH_PROVIDER_LINUXDO, oauth_user_id=oauth_user_id, ) if user: # 更新用户信息(头像等可能变化) await self.user_repo.update( user, nickname=oauth_user.name or oauth_user.username, avatar_url=oauth_user.avatar_url, ) await self.user_repo.commit() return user, False # 检查用户名是否已存在 username = oauth_user.username.lower() existing_user = await self.user_repo.get_by_username(username) if existing_user: # 用户名冲突,添加后缀 username = f"{username}_{oauth_user_id[:8]}" # 创建新用户 user = await self.user_repo.create( username=username, email=oauth_user.email, nickname=oauth_user.name or oauth_user.username, avatar_url=oauth_user.avatar_url, oauth_provider=OAUTH_PROVIDER_LINUXDO, oauth_user_id=oauth_user_id, hashed_password=None, # OAuth2 用户无密码 is_active=oauth_user.active, ) await self.user_repo.commit() logger.info(f"创建 OAuth2 用户: {user.username} (provider={OAUTH_PROVIDER_LINUXDO})") return user, True def _create_tokens(self, user: User) -> TokenResponse: """ 为用户创建 JWT 令牌 Args: user: 用户对象 Returns: 令牌响应 """ access_token = create_access_token( subject=user.id, extra_claims={ "username": user.username, "is_superuser": user.is_superuser, "oauth_provider": user.oauth_provider, }, ) refresh_token = create_refresh_token(subject=user.id) return TokenResponse( access_token=access_token, refresh_token=refresh_token, token_type="Bearer", expires_in=settings.access_token_expire_minutes * 60, )