158 lines
4.7 KiB
Python
158 lines
4.7 KiB
Python
"""
|
||
OAuth2 认证 API
|
||
|
||
提供 OAuth2 第三方登录功能。
|
||
"""
|
||
|
||
from typing import Annotated
|
||
|
||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.api.deps import DbSession
|
||
from app.core.config import settings
|
||
from app.core.exceptions import AuthenticationError
|
||
from app.schemas.base import APIResponse
|
||
from app.schemas.oauth2 import OAuth2AuthorizeResponse, OAuth2LoginResponse
|
||
from app.services.oauth2 import OAuth2Service, OAuth2StateError
|
||
|
||
router = APIRouter()
|
||
|
||
|
||
def get_oauth2_service(session: DbSession) -> OAuth2Service:
|
||
"""获取 OAuth2 服务实例"""
|
||
return OAuth2Service(session)
|
||
|
||
|
||
def _get_redirect_uri(request: Request) -> str:
|
||
"""
|
||
获取 OAuth2 回调 URL
|
||
|
||
根据请求动态构建完整的回调 URL
|
||
"""
|
||
# 优先使用 X-Forwarded 头(反向代理场景)
|
||
scheme = request.headers.get("X-Forwarded-Proto", request.url.scheme)
|
||
host = request.headers.get("X-Forwarded-Host", request.url.netloc)
|
||
|
||
return f"{scheme}://{host}{settings.oauth2_callback_path}"
|
||
|
||
|
||
@router.get(
|
||
"/authorize",
|
||
response_model=APIResponse[OAuth2AuthorizeResponse],
|
||
summary="获取 OAuth2 授权 URL",
|
||
description="获取第三方登录授权页面 URL",
|
||
)
|
||
async def get_authorize_url(
|
||
request: Request,
|
||
session: DbSession,
|
||
) -> APIResponse[OAuth2AuthorizeResponse]:
|
||
"""
|
||
获取 OAuth2 授权 URL
|
||
|
||
返回授权页面 URL 和状态码,客户端应重定向用户到该 URL。
|
||
|
||
流程:
|
||
1. 前端调用此接口获取授权 URL
|
||
2. 前端将用户重定向到授权 URL
|
||
3. 用户在第三方平台完成授权
|
||
4. 第三方平台重定向回回调 URL,携带 code 和 state
|
||
5. 前端调用 /callback 接口完成登录
|
||
"""
|
||
# 检查 OAuth2 是否已配置
|
||
if not settings.oauth2_client_id or not settings.oauth2_client_secret:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||
detail="OAuth2 登录未配置",
|
||
)
|
||
|
||
oauth2_service = OAuth2Service(session)
|
||
redirect_uri = _get_redirect_uri(request)
|
||
|
||
authorize_url, state = oauth2_service.generate_authorize_url(redirect_uri)
|
||
|
||
return APIResponse.ok(
|
||
data=OAuth2AuthorizeResponse(
|
||
authorize_url=authorize_url,
|
||
state=state,
|
||
),
|
||
message="请重定向到授权 URL",
|
||
)
|
||
|
||
|
||
@router.get(
|
||
"/callback",
|
||
response_model=APIResponse[OAuth2LoginResponse],
|
||
summary="OAuth2 回调",
|
||
description="处理 OAuth2 授权回调,完成登录",
|
||
)
|
||
async def oauth2_callback(
|
||
request: Request,
|
||
code: Annotated[str, Query(description="授权码")],
|
||
state: Annotated[str, Query(description="状态码")],
|
||
session: DbSession,
|
||
) -> APIResponse[OAuth2LoginResponse]:
|
||
"""
|
||
OAuth2 回调接口
|
||
|
||
处理第三方平台的授权回调:
|
||
1. 验证 state 防止 CSRF 攻击
|
||
2. 用 code 换取访问令牌
|
||
3. 获取用户信息
|
||
4. 创建或关联本地用户
|
||
5. 返回 JWT 令牌
|
||
"""
|
||
oauth2_service = OAuth2Service(session)
|
||
redirect_uri = _get_redirect_uri(request)
|
||
|
||
try:
|
||
user, tokens, is_new_user = await oauth2_service.authenticate(
|
||
code=code,
|
||
state=state,
|
||
redirect_uri=redirect_uri,
|
||
)
|
||
|
||
return APIResponse.ok(
|
||
data=OAuth2LoginResponse(
|
||
access_token=tokens.access_token,
|
||
refresh_token=tokens.refresh_token,
|
||
token_type=tokens.token_type,
|
||
expires_in=tokens.expires_in,
|
||
is_new_user=is_new_user,
|
||
),
|
||
message="登录成功" if not is_new_user else "注册并登录成功",
|
||
)
|
||
|
||
except OAuth2StateError as e:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_400_BAD_REQUEST,
|
||
detail=e.message,
|
||
)
|
||
except AuthenticationError as e:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||
detail=e.message,
|
||
)
|
||
|
||
|
||
@router.post(
|
||
"/callback",
|
||
response_model=APIResponse[OAuth2LoginResponse],
|
||
summary="OAuth2 回调 (POST)",
|
||
description="处理 OAuth2 授权回调(POST 方式)",
|
||
)
|
||
async def oauth2_callback_post(
|
||
request: Request,
|
||
code: Annotated[str, Query(description="授权码")],
|
||
state: Annotated[str, Query(description="状态码")],
|
||
session: DbSession,
|
||
) -> APIResponse[OAuth2LoginResponse]:
|
||
"""
|
||
OAuth2 回调接口(POST 方式)
|
||
|
||
某些场景下前端可能使用 POST 方式调用回调。
|
||
逻辑与 GET 方式相同。
|
||
"""
|
||
return await oauth2_callback(request, code, state, session)
|
||
|