# app/api/routes/authentication.py
from __future__ import annotations

import html
import logging
import secrets
from datetime import datetime, timedelta, timezone
from typing import Optional, Any, TYPE_CHECKING

from fastapi import APIRouter, Body, Depends, HTTPException, Request, Response
from starlette.concurrency import run_in_threadpool
from starlette.status import HTTP_201_CREATED, HTTP_400_BAD_REQUEST, HTTP_401_UNAUTHORIZED

from app.api.dependencies.database import get_repository
from app.core.config import get_app_settings
from app.core.settings.app import AppSettings
from app.db.errors import EntityDoesNotExist
from app.db.repositories.users import UsersRepository
# Optional import: the email_codes repository may be absent at runtime.
# Only ImportError (incl. ModuleNotFoundError) means "not available". Any other
# exception raised while importing the module is a real bug and must surface,
# instead of silently disabling email-code verification in the routes below.
try:
    from app.db.repositories.email_codes import EmailCodesRepository  # type: ignore
    HAS_EMAIL_CODES_REPO = True
except ImportError:  # pragma: no cover
    EmailCodesRepository = None  # type: ignore
    HAS_EMAIL_CODES_REPO = False
# 仅用于类型检查(让 Pylance/pyright 认识名字,但运行期不导入)
if TYPE_CHECKING: # pragma: no cover
from app.db.repositories.email_codes import EmailCodesRepository as _EmailCodesRepositoryT # noqa: F401
from app.models.schemas.users import (
UserInLogin,
UserInResponse,
UserWithToken,
RegisterWithEmailIn,
)
from app.models.schemas.email_code import EmailCodeSendIn, EmailCodeSendOut
from app.resources import strings
from app.services import jwt
from app.services.mailer import send_email
from app.services.authentication import (
check_email_is_taken,
assert_passwords_match,
make_unique_username,
)
# Router that the authentication endpoints below register on via @router.post.
router: APIRouter = APIRouter()

# ================= Cookie helpers (minimal change, no new files needed) =================
# Name of the HttpOnly cookie that carries the refresh token.
REFRESH_COOKIE_NAME: str = "refresh_token"
def set_refresh_cookie(
    resp: Response,
    token: str,
    *,
    max_age_days: int = 30,
    secure: bool = True,
) -> None:
    """Attach the refresh token to ``resp`` as an HttpOnly cookie.

    The refresh token is delivered only through this cookie.

    - ``httponly=True``: the cookie is not readable from page JavaScript.
    - ``samesite="lax"``: limits cross-site form abuse.
    - ``path="/api/auth"``: narrows the cookie's scope to the auth endpoints.

    Args:
        resp: Response the ``Set-Cookie`` header is written to.
        token: Encoded refresh token to store.
        max_age_days: Cookie lifetime in days (converted to seconds for ``max_age``).
        secure: Send the cookie over HTTPS only. Keep the default ``True`` in
            production; pass ``False`` for local plain-HTTP development instead of
            editing this function.
    """
    resp.set_cookie(
        key=REFRESH_COOKIE_NAME,
        value=token,
        max_age=max_age_days * 24 * 3600,
        httponly=True,
        secure=secure,
        samesite="lax",
        path="/api/auth",
    )
def clear_refresh_cookie(resp: Response, *, secure: bool = True) -> None:
    """Expire the refresh-token cookie on the client.

    The name, path and other attributes mirror those used by ``set_refresh_cookie``
    so that the expiring ``Set-Cookie`` header targets the same cookie.

    Args:
        resp: Response the expiring ``Set-Cookie`` header is written to.
        secure: Should match the value used when the cookie was set. Modern
            browsers ignore a ``Secure`` cookie header received over plain HTTP,
            so local HTTP setups need ``False`` here as well.
    """
    resp.delete_cookie(
        key=REFRESH_COOKIE_NAME,
        path="/api/auth",
        httponly=True,
        secure=secure,
        samesite="lax",
    )
# 为了兼容“可选的验证码仓库”,构造一个可交给 Depends 的工厂
def _provide_optional_email_codes_repo():
    """Build a ``Depends`` target for the optional email-codes repository.

    If the repository class could not be imported, return an async dependency
    that resolves to ``None``. Handlers check for ``None`` and skip code
    storage/verification. Otherwise delegate to ``get_repository``.
    """
    if not HAS_EMAIL_CODES_REPO:
        async def _no_repo():
            return None

        return _no_repo
    return get_repository(EmailCodesRepository)  # type: ignore[name-defined]
# ========= 发送邮箱验证码 =========
@router.post(
    "/email-code",
    response_model=EmailCodeSendOut,
    name="auth:email-code",
)
async def send_email_code(
    payload: EmailCodeSendIn = Body(...),
    settings: AppSettings = Depends(get_app_settings),
    email_codes_repo: Optional[Any] = Depends(_provide_optional_email_codes_repo()),
) -> EmailCodeSendOut:
    """Send a 6-digit email verification code, storing it when a repository exists.

    Steps:
    1) Generate the code with ``secrets``. Verification codes are security tokens,
       so the predictable ``random`` module must not be used.
    2) Compute the expiry from ``settings.email_code_expires_minutes``.
    3) Persist the code via ``email_codes_repo`` if it is available.
    4) Email the code to ``payload.email``.

    Returns:
        ``EmailCodeSendOut(ok=True)`` after the mail has been handed to ``send_email``.
    """
    # 1) 6-digit numeric code (zero-padded) from a CSPRNG.
    code = f"{secrets.randbelow(1_000_000):06d}"

    # 2) Naive UTC expiry: the same value datetime.utcnow() produced, without the
    #    3.12 deprecation. Kept naive because the repository's storage format is
    #    not visible here.
    now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
    expires_at = now_utc + timedelta(minutes=settings.email_code_expires_minutes)

    # 3) Persist the code (optional).
    if email_codes_repo is not None:
        await email_codes_repo.create_code(  # type: ignore[attr-defined]
            email=payload.email,
            code=code,
            scene=payload.scene,
            expires_at=expires_at,
        )

    # 4) Send the mail.
    # NOTE(review): scene is interpolated raw into the subject header; confirm the
    # schema constrains it or that send_email rejects CR/LF in headers.
    subject = f"【AI平台】{payload.scene} 验证码:{code}"
    # scene comes from the request body, so escape it before embedding it in HTML.
    scene_html = html.escape(f"{payload.scene}")
    body = f"""
您好!
您正在进行 {scene_html} 操作,本次验证码为:
{code}
有效期:{settings.email_code_expires_minutes} 分钟;请勿泄露给他人。
"""
    # send_email is called without await, i.e. synchronously. Run it in the
    # threadpool so mail I/O does not block the event loop of this async handler.
    await run_in_threadpool(send_email, payload.email, subject, body)
    return EmailCodeSendOut(ok=True)
# ========= 登录 =========
@router.post(
    "/login",
    response_model=UserInResponse,
    response_model_exclude_none=True,
    name="auth:login",
)
async def login(
    response: Response,
    user_login: UserInLogin = Body(..., embed=True, alias="user"),
    users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
    settings: AppSettings = Depends(get_app_settings),
) -> UserInResponse:
    """Log in with email + password.

    On success the access token is returned in the response body and the refresh
    token is set as an HttpOnly cookie. An unknown email and a wrong password
    produce the same 400 error.
    """
    login_failed = HTTPException(
        status_code=HTTP_400_BAD_REQUEST,
        detail=strings.INCORRECT_LOGIN_INPUT,
    )

    try:
        user = await users_repo.get_user_by_email(email=user_login.email)
    except EntityDoesNotExist as missing:
        raise login_failed from missing

    password_ok = user.check_password(user_login.password)
    if not password_ok:
        raise login_failed

    signing_key = str(settings.secret_key.get_secret_value())
    access_token = jwt.create_access_token_for_user(user, signing_key)
    refresh_token = jwt.create_refresh_token_for_user(user, signing_key)

    # The refresh token is delivered only through the HttpOnly cookie.
    set_refresh_cookie(
        response,
        refresh_token,
        max_age_days=jwt.REFRESH_TOKEN_EXPIRE_DAYS,
    )

    profile = UserWithToken(
        username=user.username,
        email=user.email,
        bio=user.bio,
        image=user.image,
        token=access_token,  # access token stays in the body for frontend compatibility
        email_verified=getattr(user, "email_verified", False),
        roles=getattr(user, "roles", []),
    )
    return UserInResponse(user=profile)
# ========= 注册 =========
@router.post(
    "",
    status_code=HTTP_201_CREATED,
    response_model=UserInResponse,
    response_model_exclude_none=True,
    name="auth:register",
)
async def register(
    response: Response,
    payload: RegisterWithEmailIn = Body(..., embed=True, alias="user"),
    users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
    settings: AppSettings = Depends(get_app_settings),
    email_codes_repo: Optional[Any] = Depends(_provide_optional_email_codes_repo()),
) -> UserInResponse:
    """Register a new account.

    Flow:
    1) Check that password and confirmation match.
    2) Check that the email is not already taken.
    3) Verify and consume the "register" email code (only if the code repository exists).
    4) Derive a unique username from the email.
    5) Create the user.
    6) Mark the email as verified -- only when step 3 actually verified a code and
       the users repository provides ``set_email_verified``.
    7) Issue the access token (body) and the refresh token (HttpOnly cookie).

    Raises:
        HTTPException(400): passwords differ, email taken, or invalid/expired code.
    """
    # 1) Password and confirmation must match.
    try:
        assert_passwords_match(payload.password, payload.confirm_password)
    except ValueError as mismatch:
        raise HTTPException(
            status_code=HTTP_400_BAD_REQUEST,
            detail="Passwords do not match",
        ) from mismatch

    # 2) The email must not belong to an existing account.
    if await check_email_is_taken(users_repo, payload.email):
        raise HTTPException(
            status_code=HTTP_400_BAD_REQUEST,
            detail=strings.EMAIL_TAKEN,
        )

    # 3) Verify (and consume) the registration code. Without a code repository
    #    nothing is verified, so track whether verification really happened.
    email_code_verified = False
    if email_codes_repo is not None:
        ok = await email_codes_repo.verify_and_consume(  # type: ignore[attr-defined]
            email=payload.email,
            code=payload.code,
            scene="register",
        )
        if not ok:
            raise HTTPException(
                status_code=HTTP_400_BAD_REQUEST,
                detail="Invalid or expired verification code",
            )
        email_code_verified = True

    # 4) Unique username derived from the email.
    username = await make_unique_username(users_repo, payload.email)

    # 5) Create the account.
    user = await users_repo.create_user(
        username=username,
        email=payload.email,
        password=payload.password,
    )

    # 6) Flag the email as verified only if a code was checked in step 3.
    #    Best effort: a failure is logged but does not fail the registration.
    if email_code_verified and hasattr(users_repo, "set_email_verified"):
        try:
            await users_repo.set_email_verified(email=payload.email, verified=True)  # type: ignore[attr-defined]
            user = await users_repo.get_user_by_email(email=payload.email)
        except Exception:
            logging.getLogger(__name__).exception(
                "Could not mark email as verified for new user %s", username
            )

    # 7) Access token in the body, refresh token in the HttpOnly cookie.
    secret = str(settings.secret_key.get_secret_value())
    access = jwt.create_access_token_for_user(user, secret)
    refresh = jwt.create_refresh_token_for_user(user, secret)
    set_refresh_cookie(response, refresh, max_age_days=jwt.REFRESH_TOKEN_EXPIRE_DAYS)

    return UserInResponse(
        user=UserWithToken(
            username=user.username,
            email=user.email,
            bio=user.bio,
            image=user.image,
            token=access,
            # If the user model has no such field, report what this request established.
            email_verified=getattr(user, "email_verified", email_code_verified),
            roles=getattr(user, "roles", []),
        ),
    )
# ========= 刷新 Access(仅 Cookie 取 refresh)=========
@router.post(
    "/refresh",
    name="auth:refresh",
)
async def refresh_access_token(
    request: Request,
    response: Response,
    users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
    settings: AppSettings = Depends(get_app_settings),
) -> dict:
    """Issue a new access token using the refresh token from the HttpOnly cookie.

    Minimal version: the refresh token is not rotated (rotation/replay detection
    needs extra storage). The same refresh token is written back to the cookie.

    Returns:
        ``{"token": <access token>, "expires_in": <access lifetime in seconds>}``.

    Raises:
        HTTPException(401): cookie missing, token invalid, or user not found.
        Repository errors other than ``EntityDoesNotExist`` propagate, because they
        are server failures, not authentication failures.
    """
    refresh = request.cookies.get(REFRESH_COOKIE_NAME)
    if not refresh:
        raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Missing refresh token")

    secret = str(settings.secret_key.get_secret_value())
    try:
        username = jwt.get_username_from_token(
            refresh, secret, expected_subject=jwt.JWT_SUBJECT_REFRESH
        )
    except ValueError as exc:
        raise HTTPException(
            status_code=HTTP_401_UNAUTHORIZED, detail="Invalid refresh token"
        ) from exc

    # Prefer lookup by username when the repository offers it. Otherwise, or if no
    # such username exists, fall back to treating the token subject as an email.
    # Only "not found" becomes 401; other errors (e.g. DB outages) must not be
    # disguised as an auth failure that logs the client out.
    user = None
    get_by_username = getattr(users_repo, "get_user_by_username", None)
    if get_by_username is not None:
        try:
            user = await get_by_username(username=username)
        except EntityDoesNotExist:
            user = None
    if user is None:
        try:
            user = await users_repo.get_user_by_email(email=username)
        except EntityDoesNotExist as exc:
            raise HTTPException(
                status_code=HTTP_401_UNAUTHORIZED, detail="User not found"
            ) from exc
    if user is None:
        raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="User not found")

    access = jwt.create_access_token_for_user(user, secret)
    # No rotation: re-set the same refresh value, which only resets the cookie's
    # max_age. The token's own claims are unchanged, so this does not extend its
    # validity (presumably get_username_from_token enforces expiry -- confirm).
    set_refresh_cookie(response, refresh, max_age_days=jwt.REFRESH_TOKEN_EXPIRE_DAYS)
    return {"token": access, "expires_in": jwt.ACCESS_TOKEN_EXPIRE_MINUTES * 60}
# ========= 登出(清 Cookie;前端清本地 access)=========
@router.post("/logout", name="auth:logout")
async def logout(response: Response) -> dict:
    """Log out by expiring the refresh-token cookie.

    The access token is held by the frontend, which is expected to discard it
    itself. The server side only clears the refresh cookie.
    """
    clear_refresh_cookie(response)
    result = {"ok": True}
    return result