# app/services/password_reset.py
import os
import hashlib
import secrets
from datetime import datetime, timedelta, timezone
from fastapi import HTTPException, Request
from asyncpg import Connection
from app.db.repositories.users import UsersRepository
from app.db.queries.queries import queries # aiosql 生成的 Queries 对象
from app.services import security # ✅ 使用项目原有 passlib 封装
from app.db.errors import EntityDoesNotExist # 用于兜底 try/except
# 业务常量
RESET_SCENE = "reset"
RESET_PURPOSE = "reset"
CODE_TTL_MINUTES = 30 # 验证码有效期(分钟)
# ===== 小工具 =====
def _sha256_hex(s: str) -> str:
return hashlib.sha256(s.encode("utf-8")).hexdigest()
def _email_html(code: str) -> str:
return f"""
重置你的密码
你的验证码({CODE_TTL_MINUTES} 分钟内有效):
{code}
若非本人操作请忽略此邮件。
"""
def _first_row(maybe_rows):
"""
aiosql + asyncpg 在 SELECT 时可能返回:
- asyncpg.Record
- list[Record]
- dict-like
统一取“第一条/单条”。
"""
if maybe_rows is None:
return None
if isinstance(maybe_rows, list):
return maybe_rows[0] if maybe_rows else None
return maybe_rows
def _get_key(row, key: str):
"""
兼容 asyncpg.Record / dict / tuple(list)
仅用于取 'id' 这种关键字段
"""
if row is None:
return None
# dict-like / Record
try:
if key in row:
return row[key] # type: ignore[index]
except Exception:
pass
# 某些驱动可能支持 .get
try:
return row.get(key) # type: ignore[attr-defined]
except Exception:
pass
# 最后尝试属性
return getattr(row, key, None)
async def _get_user_by_email_optional(users_repo: UsersRepository, *, email: str):
"""
安全获取用户:
- 若仓库实现了 get_user_by_email_optional,直接用
- 否则回退到 get_user_by_email,并用 try/except 屏蔽不存在异常
返回 UserInDB 或 None
"""
# 新接口:优先调用
if hasattr(users_repo, "get_user_by_email_optional"):
try:
return await users_repo.get_user_by_email_optional(email=email) # type: ignore[attr-defined]
except Exception:
return None
# 旧接口:try/except 防止抛出不存在
try:
return await users_repo.get_user_by_email(email=email)
except EntityDoesNotExist:
return None
# ===== 主流程 =====
async def send_reset_code_by_email(
request: Request,
conn: Connection,
users_repo: UsersRepository,
email: str,
) -> None:
"""
若邮箱存在:生成 6 位验证码 -> 只存哈希 -> 发送邮件(或开发阶段打印)
若不存在:静默返回,防止枚举邮箱
"""
user = await _get_user_by_email_optional(users_repo, email=email)
if not user:
return # 静默
# 6 位数字验证码(明文只用于发送/展示,数据库只存哈希)
code = f"{secrets.randbelow(1_000_000):06d}"
code_hash = _sha256_hex(code)
expires_at = datetime.now(timezone.utc) + timedelta(minutes=CODE_TTL_MINUTES)
request_ip = request.client.host if request.client else None
user_agent = request.headers.get("user-agent", "")
await queries.create_email_code(
conn,
email=email,
scene=RESET_SCENE,
purpose=RESET_PURPOSE,
code_hash=code_hash,
expires_at=expires_at,
request_ip=request_ip,
user_agent=user_agent,
)
# === 发送邮件 ===
try:
# 如果你已有统一邮件服务,可直接调用;没有则打印在开发日志
from app.services.mailer import send_email # 可选
# 你的 send_email 若是异步函数,这里 await;若是同步也能正常抛异常被捕获
maybe_coro = send_email(
to_email=email,
subject="重置密码验证码",
html=_email_html(code),
)
if hasattr(maybe_coro, "__await__"):
await maybe_coro # 兼容 async 版本
except Exception:
print(f"[DEV] reset code for {email}: {code} (expires in {CODE_TTL_MINUTES}m)")
async def reset_password_with_code(
conn: Connection,
users_repo: UsersRepository,
*,
email: str,
code: str,
new_password: str,
) -> None:
"""
校验验证码 -> 修改密码 -> 标记验证码已使用 -> 清理历史
"""
code_hash = _sha256_hex(code.strip())
# 1) 校验验证码(只接受未使用且未过期)
rec = await queries.get_valid_email_code(
conn,
email=email,
scene=RESET_SCENE,
purpose=RESET_PURPOSE,
code_hash=code_hash,
)
rec = _first_row(rec)
if rec is None:
raise HTTPException(status_code=400, detail="验证码无效或已过期")
# 2) 查用户(安全获取,避免抛异常 & 防枚举)
user = await _get_user_by_email_optional(users_repo, email=email)
if user is None:
# 与验证码错误同样提示,避免暴露邮箱存在性
raise HTTPException(status_code=400, detail="验证码无效或已过期")
# 3) 生成新 salt / hash —— ✅ 使用项目原有 passlib 方案
# 关键点:和登录校验保持一致,对 (salt + plain_password) 做 passlib 哈希
new_salt = os.urandom(16).hex()
new_hashed = security.get_password_hash(new_salt + new_password)
# 4) 优先用 id 更新;若没有 id(历史坑),则回退用 email 更新
updated = None
try:
user_id = getattr(user, "id", None)
if user_id:
updated = await queries.update_user_password_by_id(
conn,
id=user_id,
new_salt=new_salt,
new_password=new_hashed, # ✅ passlib 生成的带前缀哈希
)
else:
updated = await queries.update_user_password_by_email(
conn,
email=email,
new_salt=new_salt,
new_password=new_hashed,
)
except Exception:
# 极端情况下,id 更新失败也再补 email 更新,确保不中断
updated = await queries.update_user_password_by_email(
conn,
email=email,
new_salt=new_salt,
new_password=new_hashed,
)
# aiosql 有时会返回 list,若是空 list 视为失败
if isinstance(updated, list) and not updated:
raise HTTPException(status_code=500, detail="密码更新失败")
# 5) 标记验证码已用 & 清理
rec_id = _get_key(rec, "id")
if rec_id is not None:
await queries.mark_email_code_used(conn, id=rec_id)
else:
print("[WARN] Could not resolve email_code.id to mark consumed.")
await queries.delete_expired_email_codes(conn)