AI-News/backend/app/services/password_reset.py
2025-12-04 10:04:21 +08:00

219 lines
7.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# app/services/password_reset.py
import os
import hashlib
import secrets
from datetime import datetime, timedelta, timezone
from fastapi import HTTPException, Request
from asyncpg import Connection
from app.db.repositories.users import UsersRepository
from app.db.queries.queries import queries # aiosql 生成的 Queries 对象
from app.services import security # ✅ 使用项目原有 passlib 封装
from app.db.errors import EntityDoesNotExist # 用于兜底 try/except
# 业务常量
RESET_SCENE = "reset"
RESET_PURPOSE = "reset"
CODE_TTL_MINUTES = 30 # 验证码有效期(分钟)
# ===== 小工具 =====
def _sha256_hex(s: str) -> str:
return hashlib.sha256(s.encode("utf-8")).hexdigest()
def _email_html(code: str) -> str:
return f"""
<div style="font-family:system-ui,-apple-system,Segoe UI,Roboto">
<h2>重置你的密码</h2>
<p>你的验证码({CODE_TTL_MINUTES} 分钟内有效):</p>
<p style="font-size:22px;letter-spacing:2px;"><b>{code}</b></p>
<p style="color:#666">若非本人操作请忽略此邮件。</p>
</div>
"""
def _first_row(maybe_rows):
"""
aiosql + asyncpg 在 SELECT 时可能返回:
- asyncpg.Record
- list[Record]
- dict-like
统一取“第一条/单条”。
"""
if maybe_rows is None:
return None
if isinstance(maybe_rows, list):
return maybe_rows[0] if maybe_rows else None
return maybe_rows
def _get_key(row, key: str):
"""
兼容 asyncpg.Record / dict / tuple(list)
仅用于取 'id' 这种关键字段
"""
if row is None:
return None
# dict-like / Record
try:
if key in row:
return row[key] # type: ignore[index]
except Exception:
pass
# 某些驱动可能支持 .get
try:
return row.get(key) # type: ignore[attr-defined]
except Exception:
pass
# 最后尝试属性
return getattr(row, key, None)
async def _get_user_by_email_optional(users_repo: UsersRepository, *, email: str):
"""
安全获取用户:
- 若仓库实现了 get_user_by_email_optional直接用
- 否则回退到 get_user_by_email并用 try/except 屏蔽不存在异常
返回 UserInDB 或 None
"""
# 新接口:优先调用
if hasattr(users_repo, "get_user_by_email_optional"):
try:
return await users_repo.get_user_by_email_optional(email=email) # type: ignore[attr-defined]
except Exception:
return None
# 旧接口try/except 防止抛出不存在
try:
return await users_repo.get_user_by_email(email=email)
except EntityDoesNotExist:
return None
# ===== 主流程 =====
async def send_reset_code_by_email(
request: Request,
conn: Connection,
users_repo: UsersRepository,
email: str,
) -> None:
"""
若邮箱存在:生成 6 位验证码 -> 只存哈希 -> 发送邮件(或开发阶段打印)
若不存在:静默返回,防止枚举邮箱
"""
user = await _get_user_by_email_optional(users_repo, email=email)
if not user:
return # 静默
# 6 位数字验证码(明文只用于发送/展示,数据库只存哈希)
code = f"{secrets.randbelow(1_000_000):06d}"
code_hash = _sha256_hex(code)
expires_at = datetime.now(timezone.utc) + timedelta(minutes=CODE_TTL_MINUTES)
request_ip = request.client.host if request.client else None
user_agent = request.headers.get("user-agent", "")
await queries.create_email_code(
conn,
email=email,
scene=RESET_SCENE,
purpose=RESET_PURPOSE,
code_hash=code_hash,
expires_at=expires_at,
request_ip=request_ip,
user_agent=user_agent,
)
# === 发送邮件 ===
try:
# 如果你已有统一邮件服务,可直接调用;没有则打印在开发日志
from app.services.mailer import send_email # 可选
# 你的 send_email 若是异步函数,这里 await若是同步也能正常抛异常被捕获
maybe_coro = send_email(
to_email=email,
subject="重置密码验证码",
html=_email_html(code),
)
if hasattr(maybe_coro, "__await__"):
await maybe_coro # 兼容 async 版本
except Exception:
print(f"[DEV] reset code for {email}: {code} (expires in {CODE_TTL_MINUTES}m)")
async def reset_password_with_code(
conn: Connection,
users_repo: UsersRepository,
*,
email: str,
code: str,
new_password: str,
) -> None:
"""
校验验证码 -> 修改密码 -> 标记验证码已使用 -> 清理历史
"""
code_hash = _sha256_hex(code.strip())
# 1) 校验验证码(只接受未使用且未过期)
rec = await queries.get_valid_email_code(
conn,
email=email,
scene=RESET_SCENE,
purpose=RESET_PURPOSE,
code_hash=code_hash,
)
rec = _first_row(rec)
if rec is None:
raise HTTPException(status_code=400, detail="验证码无效或已过期")
# 2) 查用户(安全获取,避免抛异常 & 防枚举)
user = await _get_user_by_email_optional(users_repo, email=email)
if user is None:
# 与验证码错误同样提示,避免暴露邮箱存在性
raise HTTPException(status_code=400, detail="验证码无效或已过期")
# 3) 生成新 salt / hash —— ✅ 使用项目原有 passlib 方案
# 关键点:和登录校验保持一致,对 (salt + plain_password) 做 passlib 哈希
new_salt = os.urandom(16).hex()
new_hashed = security.get_password_hash(new_salt + new_password)
# 4) 优先用 id 更新;若没有 id历史坑则回退用 email 更新
updated = None
try:
user_id = getattr(user, "id", None)
if user_id:
updated = await queries.update_user_password_by_id(
conn,
id=user_id,
new_salt=new_salt,
new_password=new_hashed, # ✅ passlib 生成的带前缀哈希
)
else:
updated = await queries.update_user_password_by_email(
conn,
email=email,
new_salt=new_salt,
new_password=new_hashed,
)
except Exception:
# 极端情况下id 更新失败也再补 email 更新,确保不中断
updated = await queries.update_user_password_by_email(
conn,
email=email,
new_salt=new_salt,
new_password=new_hashed,
)
# aiosql 有时会返回 list若是空 list 视为失败
if isinstance(updated, list) and not updated:
raise HTTPException(status_code=500, detail="密码更新失败")
# 5) 标记验证码已用 & 清理
rec_id = _get_key(rec, "id")
if rec_id is not None:
await queries.mark_email_code_used(conn, id=rec_id)
else:
print("[WARN] Could not resolve email_code.id to mark consumed.")
await queries.delete_expired_email_codes(conn)