322 lines
11 KiB
Python
322 lines
11 KiB
Python
# app/api/routes/authentication.py
|
||
from __future__ import annotations
|
||
|
||
from typing import Optional, Any, TYPE_CHECKING
|
||
from datetime import datetime, timedelta
|
||
|
||
from fastapi import APIRouter, Body, Depends, HTTPException, Request, Response
|
||
from starlette.status import HTTP_201_CREATED, HTTP_400_BAD_REQUEST, HTTP_401_UNAUTHORIZED
|
||
|
||
from app.api.dependencies.database import get_repository
|
||
from app.core.config import get_app_settings
|
||
from app.core.settings.app import AppSettings
|
||
from app.db.errors import EntityDoesNotExist
|
||
from app.db.repositories.users import UsersRepository
|
||
|
||
# Optional import: the email_codes repository may be absent at runtime.
# Only ImportError means "repository unavailable". Any other exception raised
# while importing that module is a real bug and must surface, instead of
# silently disabling verification-code storage and checking.
try:
    from app.db.repositories.email_codes import EmailCodesRepository  # type: ignore

    HAS_EMAIL_CODES_REPO = True
except ImportError:  # pragma: no cover
    EmailCodesRepository = None  # type: ignore
    HAS_EMAIL_CODES_REPO = False

# Type-checking-only import: lets Pylance/pyright resolve the name without
# importing it at runtime.
if TYPE_CHECKING:  # pragma: no cover
    from app.db.repositories.email_codes import EmailCodesRepository as _EmailCodesRepositoryT  # noqa: F401
|
||
|
||
from app.models.schemas.users import (
|
||
UserInLogin,
|
||
UserInResponse,
|
||
UserWithToken,
|
||
RegisterWithEmailIn,
|
||
)
|
||
from app.models.schemas.email_code import EmailCodeSendIn, EmailCodeSendOut
|
||
from app.resources import strings
|
||
from app.services import jwt
|
||
from app.services.mailer import send_email
|
||
from app.services.authentication import (
|
||
check_email_is_taken,
|
||
assert_passwords_match,
|
||
make_unique_username,
|
||
)
|
||
|
||
# Router collecting every authentication endpoint defined in this module.
router = APIRouter()

# ================= Cookie helpers (minimal change, no new files needed) =================

# Name of the HttpOnly cookie that carries the refresh token; read back by the
# refresh endpoint and targeted by set/clear helpers below.
REFRESH_COOKIE_NAME = "refresh_token"
|
||
|
||
def set_refresh_cookie(resp: Response, token: str, *, max_age_days: int = 30) -> None:
    """Attach the refresh token to *resp* as an HttpOnly cookie.

    The refresh token is only ever delivered through this cookie:

    - ``SameSite=Lax`` limits abuse via cross-site form posts.
    - ``Secure=True`` should stay on in production; switch it to ``False`` only
      when debugging over plain local HTTP.
    - ``Path=/api/auth`` narrows the set of request paths the browser sends it to.

    Args:
        resp: Response that receives the ``Set-Cookie`` header.
        token: Encoded refresh token.
        max_age_days: Cookie lifetime in days (converted to seconds).
    """
    lifetime_seconds = max_age_days * 24 * 3600
    cookie_options = dict(
        key=REFRESH_COOKIE_NAME,
        value=token,
        max_age=lifetime_seconds,
        httponly=True,
        secure=True,  # flip to False only for local plain-http debugging
        samesite="lax",
        path="/api/auth",
    )
    resp.set_cookie(**cookie_options)
|
||
|
||
def clear_refresh_cookie(resp: Response) -> None:
    """Expire the refresh-token cookie on *resp*.

    Uses the same path, ``secure``, ``httponly`` and ``samesite`` values that
    ``set_refresh_cookie`` writes, so the deletion targets that same cookie.
    """
    resp.delete_cookie(
        REFRESH_COOKIE_NAME,
        path="/api/auth",
        secure=True,
        httponly=True,
        samesite="lax",
    )
|
||
|
||
# Factory that can be handed to ``Depends`` even when the optional
# verification-code repository is unavailable.
def _provide_optional_email_codes_repo():
    """Build a dependency that yields the email-codes repository or ``None``.

    If the repository module imported successfully (``HAS_EMAIL_CODES_REPO``),
    this delegates to ``get_repository``. Otherwise it returns an async
    dependency that always resolves to ``None``, which lets route handlers skip
    code storage and verification.
    """
    if not HAS_EMAIL_CODES_REPO:
        async def _no_repository():
            return None

        return _no_repository
    return get_repository(EmailCodesRepository)  # type: ignore[name-defined]
|
||
|
||
|
||
# ========= Send an email verification code =========
@router.post(
    "/email-code",
    response_model=EmailCodeSendOut,
    name="auth:email-code",
)
async def send_email_code(
    payload: EmailCodeSendIn = Body(...),
    settings: AppSettings = Depends(get_app_settings),
    email_codes_repo: Optional[Any] = Depends(_provide_optional_email_codes_repo()),
) -> EmailCodeSendOut:
    """Generate a 6-digit verification code, store it if possible, and email it.

    Steps:
      1. Draw the code from a CSPRNG. It is an authentication secret, so the
         predictable ``random`` module must not be used.
      2. Compute its expiry from ``settings.email_code_expires_minutes``.
      3. Persist it via ``email_codes_repo.create_code`` when the repository is
         available. The dependency yields ``None`` when the optional repository
         could not be imported.
      4. Email the code to ``payload.email``.

    Returns:
        ``EmailCodeSendOut(ok=True)`` once ``send_email`` has returned.
    """
    import secrets
    from html import escape

    from starlette.concurrency import run_in_threadpool

    # 1) Six digits, zero-padded, from a cryptographically secure source.
    code = f"{secrets.randbelow(1_000_000):06d}"

    # 2) Expiry as a naive UTC datetime, kept as in the original.
    # NOTE(review): presumably the repository compares naive timestamps;
    # confirm before switching to timezone-aware datetimes.
    expires_at = datetime.utcnow() + timedelta(minutes=settings.email_code_expires_minutes)

    # 3) Store in the database (optional).
    if email_codes_repo is not None:
        await email_codes_repo.create_code(  # type: ignore[attr-defined]
            email=payload.email,
            code=code,
            scene=payload.scene,
            expires_at=expires_at,
        )

    # 4) Send the email. `scene` comes from the client, so escape it before
    # embedding it in HTML. Formatting goes through f"{...}" exactly as before.
    scene_html = escape(f"{payload.scene}")
    subject = f"【AI平台】{payload.scene} 验证码:{code}"
    html_body = f"""
<div style="font-family:Arial,Helvetica,sans-serif;font-size:14px;line-height:1.6">
<p>您好!</p>
<p>您正在进行 <b>{scene_html}</b> 操作,本次验证码为:</p>
<p style="font-size:22px;font-weight:700;letter-spacing:2px">{code}</p>
<p>有效期:{settings.email_code_expires_minutes} 分钟;请勿泄露给他人。</p>
</div>
"""
    # send_email is synchronous (it is called without await), so run it in a
    # worker thread instead of blocking the event loop during SMTP I/O.
    await run_in_threadpool(send_email, payload.email, subject, html_body)
    return EmailCodeSendOut(ok=True)
|
||
|
||
|
||
# ========= Login =========
@router.post(
    "/login",
    response_model=UserInResponse,
    response_model_exclude_none=True,
    name="auth:login",
)
async def login(
    response: Response,
    user_login: UserInLogin = Body(..., embed=True, alias="user"),
    users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
    settings: AppSettings = Depends(get_app_settings),
) -> UserInResponse:
    """Log in with email + password.

    On success, signs an access token (returned in the body) and a refresh token
    (delivered only via the HttpOnly cookie). An unknown email and a wrong
    password both produce the same 400 error.
    """

    def _bad_credentials() -> HTTPException:
        # One shared error for both failure modes, so the response does not
        # reveal whether the email exists.
        return HTTPException(
            status_code=HTTP_400_BAD_REQUEST,
            detail=strings.INCORRECT_LOGIN_INPUT,
        )

    try:
        user = await users_repo.get_user_by_email(email=user_login.email)
    except EntityDoesNotExist as missing_user:
        raise _bad_credentials() from missing_user

    if not user.check_password(user_login.password):
        raise _bad_credentials()

    signing_key = str(settings.secret_key.get_secret_value())

    # Access token + refresh token; the lifetimes are defined in the jwt service.
    access_token = jwt.create_access_token_for_user(user, signing_key)
    refresh_token = jwt.create_refresh_token_for_user(user, signing_key)

    # The refresh token goes out only through the HttpOnly cookie.
    set_refresh_cookie(response, refresh_token, max_age_days=jwt.REFRESH_TOKEN_EXPIRE_DAYS)

    profile = UserWithToken(
        username=user.username,
        email=user.email,
        bio=user.bio,
        image=user.image,
        token=access_token,  # access token stays in the body for frontend compatibility
        email_verified=getattr(user, "email_verified", False),
        roles=getattr(user, "roles", []),
    )
    return UserInResponse(user=profile)
|
||
|
||
|
||
# ========= Registration =========
@router.post(
    "",
    status_code=HTTP_201_CREATED,
    response_model=UserInResponse,
    response_model_exclude_none=True,
    name="auth:register",
)
async def register(
    response: Response,
    payload: RegisterWithEmailIn = Body(..., embed=True, alias="user"),
    users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
    settings: AppSettings = Depends(get_app_settings),
    email_codes_repo: Optional[Any] = Depends(_provide_optional_email_codes_repo()),
) -> UserInResponse:
    """Register a new account with email + password.

    Flow:
      1) Check that both password fields match.
      2) Check that the email is not already taken.
      3) Verify and consume the email code (only when the code repository exists).
      4) Derive a unique username from the email.
      5) Create the user.
      6) Mark the email as verified, but only if step 3 actually verified a code
         and the users repository provides ``set_email_verified``.
      7) Issue an access token (returned in the body) and set the refresh cookie.

    Raises:
        HTTPException: 400 on password mismatch, taken email, or an invalid or
            expired verification code.
    """
    import logging

    # 1) Both passwords must match.
    try:
        assert_passwords_match(payload.password, payload.confirm_password)
    except ValueError as mismatch:
        raise HTTPException(
            status_code=HTTP_400_BAD_REQUEST,
            detail="Passwords do not match",
        ) from mismatch

    # 2) The email must not already be in use.
    if await check_email_is_taken(users_repo, payload.email):
        raise HTTPException(
            status_code=HTTP_400_BAD_REQUEST,
            detail=strings.EMAIL_TAKEN,
        )

    # 3) Verify the code. Without the repository no code can be checked, so the
    #    email must NOT be treated as verified.
    code_verified = False
    if email_codes_repo is not None:
        ok = await email_codes_repo.verify_and_consume(  # type: ignore[attr-defined]
            email=payload.email,
            code=payload.code,
            scene="register",
        )
        if not ok:
            raise HTTPException(
                status_code=HTTP_400_BAD_REQUEST,
                detail="Invalid or expired verification code",
            )
        code_verified = True

    # 4) Unique username derived from the email.
    username = await make_unique_username(users_repo, payload.email)

    # 5) Create the user.
    user = await users_repo.create_user(
        username=username,
        email=payload.email,
        password=payload.password,
    )

    # 6) Best effort: flag the email as verified and re-read the user. A failure
    #    here must not block registration, but it is logged rather than hidden.
    if code_verified and hasattr(users_repo, "set_email_verified"):
        try:
            await users_repo.set_email_verified(email=payload.email, verified=True)  # type: ignore[attr-defined]
            user = await users_repo.get_user_by_email(email=payload.email)
        except Exception:
            logging.getLogger(__name__).warning(
                "Could not mark email as verified after registration; continuing",
                exc_info=True,
            )

    # 7) Sign access & refresh tokens; the refresh token goes into the cookie.
    secret = str(settings.secret_key.get_secret_value())
    access = jwt.create_access_token_for_user(user, secret)
    refresh = jwt.create_refresh_token_for_user(user, secret)
    set_refresh_cookie(response, refresh, max_age_days=jwt.REFRESH_TOKEN_EXPIRE_DAYS)

    return UserInResponse(
        user=UserWithToken(
            username=user.username,
            email=user.email,
            bio=user.bio,
            image=user.image,
            token=access,
            # Fall back to what this request actually established, not a blanket True.
            email_verified=getattr(user, "email_verified", code_verified),
            roles=getattr(user, "roles", []),
        ),
    )
|
||
|
||
|
||
# ========= Refresh the access token (refresh token read from the cookie only) =========
@router.post(
    "/refresh",
    name="auth:refresh",
)
async def refresh_access_token(
    request: Request,
    response: Response,
    users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
    settings: AppSettings = Depends(get_app_settings),
) -> dict:
    """Issue a new access token from the refresh token in the HttpOnly cookie.

    The refresh token is validated (including its refresh subject), the user is
    looked up, a new access token is signed, and the same refresh cookie is set
    again. This minimal version does not rotate the refresh token; rotation and
    replay detection would need the "extra table" approach.

    Returns:
        ``{"token": <access token>, "expires_in": <seconds>}``.

    Raises:
        HTTPException: 401 when the cookie is missing, the token is invalid, or
            the user no longer exists. Other repository errors propagate instead
            of being disguised as authentication failures.
    """
    refresh = request.cookies.get(REFRESH_COOKIE_NAME)
    if not refresh:
        raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Missing refresh token")

    secret = str(settings.secret_key.get_secret_value())
    try:
        username = jwt.get_username_from_token(refresh, secret, expected_subject=jwt.JWT_SUBJECT_REFRESH)
    except ValueError as invalid:
        raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid refresh token") from invalid

    # Look the user up by username. Fall back to lookup by email only when the
    # repository has no get_user_by_username method at all.
    get_by_username = getattr(users_repo, "get_user_by_username", None)
    try:
        if get_by_username is not None:
            user = await get_by_username(username=username)
        else:
            user = await users_repo.get_user_by_email(email=username)
    except EntityDoesNotExist as missing_user:
        raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="User not found") from missing_user

    # New access token; the same refresh token stays in use (no rotation).
    access = jwt.create_access_token_for_user(user, secret)
    # Re-set the cookie with the same value so its max-age is renewed.
    set_refresh_cookie(response, refresh, max_age_days=jwt.REFRESH_TOKEN_EXPIRE_DAYS)

    return {"token": access, "expires_in": jwt.ACCESS_TOKEN_EXPIRE_MINUTES * 60}
|
||
|
||
|
||
# ========= Logout (clears the cookie; the frontend discards its local access token) =========
@router.post(
    "/logout",
    name="auth:logout",
)
async def logout(response: Response) -> dict:
    """Log out by expiring the refresh-token cookie.

    Nothing is revoked server-side here. The client is expected to drop its
    locally stored access token.
    """
    clear_refresh_cookie(response)
    result = {"ok": True}
    return result
|