Files
SatoNano/app/api/v1/endpoints/oauth2.py
2026-01-06 23:49:23 +08:00

158 lines
4.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
OAuth2 认证 API
提供 OAuth2 第三方登录功能。
"""
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import DbSession
from app.core.config import settings
from app.core.exceptions import AuthenticationError
from app.schemas.base import APIResponse
from app.schemas.oauth2 import OAuth2AuthorizeResponse, OAuth2LoginResponse
from app.services.oauth2 import OAuth2Service, OAuth2StateError
# Router for the OAuth2 endpoints; the handlers in this module attach to it
# through the @router.get / @router.post decorators.
router = APIRouter()
def get_oauth2_service(session: DbSession) -> OAuth2Service:
    """Build an ``OAuth2Service`` instance from the given session."""
    service = OAuth2Service(session)
    return service
def _get_redirect_uri(request: Request) -> str:
    """
    Build the absolute OAuth2 callback URL for the current request.

    The scheme and host are taken from the ``X-Forwarded-Proto`` and
    ``X-Forwarded-Host`` headers when they carry a value (reverse-proxy
    deployments) and fall back to the request URL otherwise. The path always
    comes from ``settings.oauth2_callback_path``.

    Args:
        request: Incoming request whose headers and URL are inspected.

    Returns:
        The callback URL in the form ``<scheme>://<host><callback path>``.
    """

    def _first_value(header_name: str, fallback: str) -> str:
        # Chained proxies may append to these headers ("https, http"); using
        # the raw value would produce a malformed URL. Keep only the first
        # entry, and fall back when the header is missing or blank.
        raw = request.headers.get(header_name, "")
        first = raw.split(",", 1)[0].strip()
        return first or fallback

    # NOTE(review): forwarded headers can be set by the client unless a
    # trusted proxy overwrites them — confirm the deployment guarantees this.
    scheme = _first_value("X-Forwarded-Proto", request.url.scheme)
    host = _first_value("X-Forwarded-Host", request.url.netloc)
    return f"{scheme}://{host}{settings.oauth2_callback_path}"
@router.get(
    "/authorize",
    response_model=APIResponse[OAuth2AuthorizeResponse],
    summary="获取 OAuth2 授权 URL",
    description="获取第三方登录授权页面 URL",
)
async def get_authorize_url(
    request: Request,
    session: DbSession,
) -> APIResponse[OAuth2AuthorizeResponse]:
    """
    Return the OAuth2 authorization URL together with its state value.

    Responds with HTTP 503 when either the OAuth2 client id or the client
    secret is missing from the settings.

    Intended flow, as described by the original author:
    1. The frontend calls this endpoint to obtain the authorization URL.
    2. The frontend redirects the user to that URL.
    3. The user completes authorization on the third-party platform.
    4. The platform redirects back to the callback URL with ``code`` and
       ``state``.
    5. The frontend calls the ``/callback`` endpoint to finish the login.
    """
    # Both credentials must be set before an authorization URL can be issued.
    oauth2_configured = (
        settings.oauth2_client_id and settings.oauth2_client_secret
    )
    if not oauth2_configured:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="OAuth2 登录未配置",
        )

    service = OAuth2Service(session)
    callback_url = _get_redirect_uri(request)
    url, state_value = service.generate_authorize_url(callback_url)

    payload = OAuth2AuthorizeResponse(authorize_url=url, state=state_value)
    return APIResponse.ok(data=payload, message="请重定向到授权 URL")
@router.get(
    "/callback",
    response_model=APIResponse[OAuth2LoginResponse],
    summary="OAuth2 回调",
    description="处理 OAuth2 授权回调,完成登录",
)
async def oauth2_callback(
    request: Request,
    code: Annotated[str, Query(description="授权码")],
    state: Annotated[str, Query(description="状态码")],
    session: DbSession,
) -> APIResponse[OAuth2LoginResponse]:
    """
    Handle the OAuth2 authorization callback and complete the login.

    Delegates to ``OAuth2Service.authenticate`` with the received ``code`` and
    ``state`` plus the callback URL rebuilt from the request. According to the
    original notes, that call is expected to validate ``state`` (CSRF
    protection), exchange ``code`` for an access token, fetch the user
    profile, and create or link the local user.

    Args:
        request: Incoming request, used to rebuild the redirect URI.
        code: Authorization code supplied as a query parameter.
        state: State value supplied as a query parameter.
        session: Session handed to ``OAuth2Service``.

    Returns:
        An ``APIResponse`` wrapping the access/refresh tokens, the token type,
        the expiry, and whether the user was newly created.

    Raises:
        HTTPException: 400 when the service raises ``OAuth2StateError``;
            401 when it raises ``AuthenticationError``.
    """
    oauth2_service = OAuth2Service(session)
    redirect_uri = _get_redirect_uri(request)

    # Only the service call can raise the handled errors, so the try body is
    # kept to that single statement.
    try:
        _user, tokens, is_new_user = await oauth2_service.authenticate(
            code=code,
            state=state,
            redirect_uri=redirect_uri,
        )
    except OAuth2StateError as e:
        # Chain explicitly so the original error stays attached as the cause.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=e.message,
        ) from e
    except AuthenticationError as e:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=e.message,
        ) from e

    return APIResponse.ok(
        data=OAuth2LoginResponse(
            access_token=tokens.access_token,
            refresh_token=tokens.refresh_token,
            token_type=tokens.token_type,
            expires_in=tokens.expires_in,
            is_new_user=is_new_user,
        ),
        message="登录成功" if not is_new_user else "注册并登录成功",
    )
@router.post(
    "/callback",
    response_model=APIResponse[OAuth2LoginResponse],
    summary="OAuth2 回调 (POST)",
    description="处理 OAuth2 授权回调（POST 方式）",
)
async def oauth2_callback_post(
    request: Request,
    code: Annotated[str, Query(description="授权码")],
    state: Annotated[str, Query(description="状态码")],
    session: DbSession,
) -> APIResponse[OAuth2LoginResponse]:
    """
    Handle the OAuth2 callback when it is invoked with POST.

    Some frontends may call the callback with POST instead of GET; the logic
    is the same, so this handler delegates to ``oauth2_callback``. ``code``
    and ``state`` are still read from the query string.

    Args:
        request: Incoming request, forwarded to the GET handler.
        code: Authorization code supplied as a query parameter.
        state: State value supplied as a query parameter.
        session: Session forwarded to the GET handler.

    Returns:
        Whatever ``oauth2_callback`` returns for the same arguments.

    Raises:
        HTTPException: Propagated unchanged from ``oauth2_callback``.
    """
    # Keyword arguments keep the delegation correct even if the GET handler's
    # parameter order changes.
    return await oauth2_callback(
        request=request,
        code=code,
        state=state,
        session=session,
    )