2025-12-04 10:04:21 +08:00

117 lines
3.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from datetime import datetime, timedelta, timezone
from typing import Dict, Optional, Literal

import jwt
from pydantic import ValidationError

from app.models.domain.users import User
from app.models.schemas.jwt import JWTMeta, JWTUser
# === Configuration ===

# Signing algorithm used both when issuing and when verifying tokens.
ALGORITHM: str = "HS256"

# Values written to the "sub" claim so access and refresh tokens can be told apart.
JWT_SUBJECT_ACCESS: str = "access"
JWT_SUBJECT_REFRESH: str = "refresh"

# Token lifetimes (per the current scheme).
ACCESS_TOKEN_EXPIRE_MINUTES: int = 15  # 15 minutes
REFRESH_TOKEN_EXPIRE_DAYS: int = 30  # 30 days
def _create_jwt_token(
    *,
    jwt_content: Dict[str, str],
    secret_key: str,
    expires_delta: timedelta,
    subject: Literal["access", "refresh"],
) -> str:
    """Sign a JWT whose payload is ``jwt_content`` plus ``exp`` and ``sub`` claims.

    Args:
        jwt_content: Claims to embed, usually produced by a Pydantic model's
            ``.dict()`` (e.g. ``JWTUser(username=...).dict()``). The mapping is
            copied first, so the caller's object is not modified.
        secret_key: Secret used to sign the token.
        expires_delta: Token lifetime, measured from the moment of signing.
        subject: Value stored in the ``sub`` claim ("access" or "refresh").

    Returns:
        The encoded JWT, signed with ``ALGORITHM``.
    """
    to_encode = jwt_content.copy()
    # Use a timezone-aware "now": datetime.utcnow() yields a naive value and is
    # deprecated since Python 3.12. PyJWT turns an aware datetime into the same
    # UTC epoch timestamp, so the resulting ``exp`` claim is unchanged.
    expire = datetime.now(timezone.utc) + expires_delta
    # exp/sub are merged last, so they override any same-named keys in jwt_content.
    to_encode.update(JWTMeta(exp=expire, sub=subject).dict())
    return jwt.encode(to_encode, secret_key, algorithm=ALGORITHM)
# ========== Access Token(给前端放到 Authorization 里用) ==========
def create_access_token_for_user(user: User, secret_key: str) -> str:
    """Issue a short-lived access token for ``user``.

    The payload carries the user's username and a ``sub`` claim of
    ``JWT_SUBJECT_ACCESS``. The token expires ``ACCESS_TOKEN_EXPIRE_MINUTES``
    minutes after it is signed.
    """
    claims = JWTUser(username=user.username).dict()
    lifetime = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    return _create_jwt_token(
        jwt_content=claims,
        secret_key=secret_key,
        expires_delta=lifetime,
        subject=JWT_SUBJECT_ACCESS,
    )
# ========== Refresh Token(仅通过 HttpOnly Cookie 下发/使用) ==========
def create_refresh_token_for_user(user: User, secret_key: str) -> str:
    """Issue a long-lived refresh token for ``user``.

    The payload carries the user's username and a ``sub`` claim of
    ``JWT_SUBJECT_REFRESH``. The token expires ``REFRESH_TOKEN_EXPIRE_DAYS``
    days after it is signed.

    Note: this minimal version uses a self-contained JWT as the refresh token.
    A more secure design would issue a random string and store only its hash
    on the server.
    """
    claims = JWTUser(username=user.username).dict()
    lifetime = timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
    return _create_jwt_token(
        jwt_content=claims,
        secret_key=secret_key,
        expires_delta=lifetime,
        subject=JWT_SUBJECT_REFRESH,
    )
# ========== 解码与校验工具 ==========
def _decode_token(token: str, secret_key: str) -> Dict:
    """Decode ``token`` with ``secret_key`` and return its raw payload.

    Only ``ALGORITHM`` is accepted, so a token signed with any other algorithm
    is rejected.

    Raises:
        ValueError: wrapping any ``jwt.PyJWTError`` raised while decoding.
    """
    accepted_algorithms = [ALGORITHM]
    try:
        payload = jwt.decode(token, secret_key, algorithms=accepted_algorithms)
    except jwt.PyJWTError as exc:
        raise ValueError("unable to decode JWT token") from exc
    return payload
def get_username_from_token(
    token: str,
    secret_key: str,
    expected_subject: Literal["access", "refresh"] = JWT_SUBJECT_ACCESS,
) -> str:
    """Decode ``token`` and return the username stored in it.

    The ``sub`` claim must equal ``expected_subject``, which prevents a refresh
    token from being accepted where an access token is expected, and vice versa:

    - protected endpoints: ``expected_subject="access"`` (the default)
    - refresh flow: ``expected_subject="refresh"``

    Raises:
        ValueError: if ``_decode_token`` rejects the token, if the ``sub`` claim
            does not match ``expected_subject``, or if the payload fails
            ``JWTUser`` validation.
    """
    payload = _decode_token(token, secret_key)

    # Check the token kind explicitly so the two token types cannot be swapped.
    actual_subject = payload.get("sub")
    if actual_subject != expected_subject:
        raise ValueError(
            f"invalid token subject: expected '{expected_subject}', got '{actual_subject}'"
        )

    # Let Pydantic validate and extract the user fields.
    try:
        user_claims = JWTUser(**payload)
    except ValidationError as validation_error:
        raise ValueError("malformed payload in token") from validation_error
    return user_claims.username
# ========== 兼容旧用法的别名(如果你项目其他地方直接调用了它) ==========
def create_jwt_token(
    *,
    jwt_content: Dict[str, str],
    secret_key: str,
    expires_delta: timedelta,
) -> str:
    """Legacy token signer, kept so existing callers keep working.

    Behaves like issuing an access token: the ``sub`` claim is always
    ``JWT_SUBJECT_ACCESS``. New code should call
    ``create_access_token_for_user`` or ``create_refresh_token_for_user``.
    """
    legacy_subject = JWT_SUBJECT_ACCESS
    return _create_jwt_token(
        subject=legacy_subject,
        jwt_content=jwt_content,
        expires_delta=expires_delta,
        secret_key=secret_key,
    )