# app/db/repositories/password_reset.py
|
||
import hashlib
|
||
from typing import Optional, Dict, Any
|
||
from datetime import datetime, timedelta, timezone
|
||
|
||
from asyncpg import Connection
|
||
from app.db.queries.queries import queries
|
||
|
||
|
||
class PasswordResetRepository:
    """Repository for one-time password-reset tokens.

    Plaintext tokens are never passed to the database layer: every query
    receives only the SHA-256 hex digest produced by ``_hash``.
    """

    def __init__(self, conn: Connection) -> None:
        # Connection used for every query issued by this repository.
        self.connection = conn

    @staticmethod
    def _hash(token: str) -> str:
        """Return the hex-encoded SHA-256 digest of ``token`` (UTF-8 encoded)."""
        return hashlib.sha256(token.encode("utf-8")).hexdigest()

    @staticmethod
    def _as_aware_utc(value: datetime) -> datetime:
        """Return ``value`` as a timezone-aware datetime.

        A naive datetime cannot be compared with the aware
        ``datetime.now(timezone.utc)`` (Python raises ``TypeError``), so naive
        values are interpreted as UTC. Aware values are returned unchanged.
        """
        # NOTE(review): assumes naive timestamps coming back from the DB are
        # UTC (tokens are created with a UTC expiry in ``create``) — confirm
        # against the column type / server timezone.
        if value.tzinfo is None:
            return value.replace(tzinfo=timezone.utc)
        return value

    async def create(
        self,
        *,
        user_id: int,
        token: str,
        ttl_minutes: int,
        request_ip: Optional[str],
        user_agent: Optional[str],
    ) -> Dict[str, Any]:
        """Create a one-time reset token, storing only the token's hash.

        Args:
            user_id: Id of the user the token is issued for.
            token: Plaintext token; only its SHA-256 digest is sent to the DB.
            ttl_minutes: Lifetime of the token, counted from now (UTC).
            request_ip: Client IP of the reset request, if known.
            user_agent: Client user agent of the reset request, if known.

        Returns:
            Whatever ``queries.create_password_reset_token`` returns —
            presumably the inserted row (a dict-like record).
        """
        expires_at = datetime.now(timezone.utc) + timedelta(minutes=ttl_minutes)
        return await queries.create_password_reset_token(
            self.connection,
            user_id=user_id,
            token_hash=self._hash(token),
            expires_at=expires_at,
            request_ip=request_ip,
            user_agent=user_agent,
        )

    async def get_valid(self, *, token: str) -> Optional[Dict[str, Any]]:
        """Look up a plaintext token and return its row only if it is usable.

        A token is usable when:
          - a row with the token's hash exists,
          - it has not been used (``used_at`` is NULL), and
          - it has not expired (``expires_at`` is strictly in the future).

        Returns:
            The matching row, or ``None`` if the token is unknown, used, or
            expired.
        """
        row = await queries.get_password_reset_token_by_hash(
            self.connection, token_hash=self._hash(token)
        )
        if not row:
            return None
        if row["used_at"] is not None:
            return None
        # Normalise before comparing: a naive ``expires_at`` would otherwise
        # make the comparison with an aware "now" raise TypeError.
        expires_at = self._as_aware_utc(row["expires_at"])
        if expires_at <= datetime.now(timezone.utc):
            return None
        return row

    async def mark_used(self, *, token_id: int) -> None:
        """Mark the reset token with id ``token_id`` as used.

        NOTE(review): ``get_valid`` followed by ``mark_used`` is a
        check-then-act sequence of two separate queries. Two concurrent
        requests could both pass ``get_valid`` for the same token unless the
        caller runs them in one transaction with a row lock, or the underlying
        UPDATE is conditional on ``used_at IS NULL``. Confirm against the SQL
        behind ``mark_password_reset_token_used``.
        """
        await queries.mark_password_reset_token_used(self.connection, id=token_id)