# 73 lines · 2.4 KiB · Python
from dataclasses import dataclass, field
from typing import Iterable, Optional

from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from .auth.jwt_utils import decode_token
from .db import get_session
from .models import RoleName, User
|
|
|
|
|
|
# Extracts "Authorization: Bearer <token>" credentials for each request.
# auto_error is disabled so that a request without usable bearer credentials
# reaches get_current_user as None, which then decides how to respond.
security: HTTPBearer = HTTPBearer(auto_error=False)
|
|
|
|
|
|
@dataclass
class AuthUser:
    """Authenticated principal produced by ``get_current_user``.

    Bundles the loaded ``User`` row with the role and customer id that were
    resolved for this request, plus the bearer token that authenticated it.
    The token is excluded from ``repr()`` so that logging or printing an
    ``AuthUser`` does not leak the credential.
    """

    # ORM row of the authenticated user.
    user: User
    # Role used by authorization checks such as require_roles.
    role_name: str
    # Customer the user is associated with, or None when there is none.
    customer_id: Optional[int]
    # Raw bearer token from the request; hidden from the generated __repr__.
    token: str = field(repr=False)
|
|
|
|
|
|
def _unauthorized(detail: str) -> HTTPException:
    """Build a 401 error carrying the bearer authentication challenge.

    RFC 6750 asks that 401 responses for bearer-protected resources include a
    ``WWW-Authenticate: Bearer`` header so clients know how to authenticate.
    """
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


async def get_current_user(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
    session: AsyncSession = Depends(get_session),
) -> AuthUser:
    """Resolve the authenticated user for the current request.

    FastAPI dependency: decodes the bearer token, loads the matching ``User``
    (with ``role`` and ``customer`` eagerly loaded via ``selectinload``) and
    wraps it in an ``AuthUser``.

    Args:
        credentials: Bearer credentials extracted by ``security``. They are
            ``None`` when the request carried none, because the scheme is
            built with ``auto_error=False``.
        session: Async database session provided by ``get_session``.

    Returns:
        An ``AuthUser`` whose role and customer id come from the token's
        ``role`` / ``customer_id`` claims when present, otherwise from the
        database row.

    Raises:
        HTTPException: 401 in any of these cases:
            - credentials are missing;
            - ``decode_token`` raises ValueError;
            - the ``sub`` claim is missing or not an integer id;
            - no user matches the ``sub`` id.
            403 when the user is not active.
    """
    if credentials is None:
        raise _unauthorized("Not authenticated")

    token = credentials.credentials
    try:
        payload = decode_token(token)
    except ValueError as exc:
        raise _unauthorized("Invalid token") from exc

    user_id = payload.get("sub")
    if not user_id:
        raise _unauthorized("Invalid token payload")
    try:
        # RFC 7519 defines "sub" as a string. A value that is not an integer
        # id must be an auth failure, not an unhandled 500.
        user_pk = int(user_id)
    except (TypeError, ValueError) as exc:
        raise _unauthorized("Invalid token payload") from exc

    user = await session.scalar(
        select(User)
        .where(User.id == user_pk)
        .options(selectinload(User.role), selectinload(User.customer))
    )
    if user is None:
        raise _unauthorized("User not found")
    if not user.is_active:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User disabled")

    # NOTE(review): token claims win over the database row. A role or customer
    # change in the DB is therefore not reflected for tokens that already
    # carry those claims.
    role_name = payload.get("role") or (user.role.name if user.role else "")
    customer_id = (
        payload.get("customer_id")
        or user.customer_id
        or (user.customer.id if user.customer is not None else None)
    )
    return AuthUser(user=user, role_name=role_name, customer_id=customer_id, token=token)
|
|
|
|
|
|
def _role_value(role):
    """Return ``role.value`` for enum-like members, otherwise ``role`` unchanged."""
    return getattr(role, "value", role)


def require_roles(roles: Iterable[RoleName]):
    """Build a FastAPI dependency that admits only users holding one of ``roles``.

    Args:
        roles: Roles allowed through. ``RoleName`` members are compared by
            their ``.value``. Plain strings are also accepted and used as-is.
            The iterable is consumed once, when this factory is called. An
            empty iterable yields a dependency that rejects every user.

    Returns:
        An async dependency that resolves the caller through
        ``get_current_user``. It returns the ``AuthUser`` unchanged when the
        caller's role is allowed. Otherwise it raises ``HTTPException`` with
        status 403 ("Insufficient role").
    """
    allowed = frozenset(_role_value(r) for r in roles)

    async def dependency(auth_user: AuthUser = Depends(get_current_user)) -> AuthUser:
        # role_name may come from user.role.name. Normalise it in case that is
        # an enum member rather than a plain string, so the comparison is
        # always value-to-value.
        if _role_value(auth_user.role_name) not in allowed:
            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Insufficient role")
        return auth_user

    return dependency
|
|
|
|
|
|
def require_admin(
    auth_user: AuthUser = Depends(require_roles((RoleName.ADMIN,))),
) -> AuthUser:
    """FastAPI dependency that admits only administrators.

    The role check is delegated to ``require_roles`` with ``RoleName.ADMIN``.
    The authenticated user it yields is passed straight back to the endpoint.
    """
    return auth_user
|