# noqa:WPS201
import logging
from typing import Callable, Optional

from fastapi import Depends, HTTPException, Security
from fastapi.security import APIKeyHeader
from starlette import requests, status
from starlette.exceptions import HTTPException as StarletteHTTPException

from app.api.dependencies.database import get_repository
from app.core.config import get_app_settings
from app.core.settings.app import AppSettings
from app.db.errors import EntityDoesNotExist
from app.db.repositories.users import UsersRepository
from app.models.domain.users import User
from app.resources import strings
from app.services import jwt

# Name of the HTTP header the token is read from.
HEADER_KEY = "Authorization"


class RWAPIKeyHeader(APIKeyHeader):
    """``APIKeyHeader`` that reports failures with the project's own message."""

    async def __call__(  # noqa: WPS610
        self,
        request: requests.Request,
    ) -> Optional[str]:
        """Return the header value extracted by the parent class.

        Raises:
            HTTPException: when the parent raises a Starlette HTTP error; the
                original status code is kept and the detail is replaced with
                ``strings.AUTHENTICATION_REQUIRED``.
        """
        try:
            return await super().__call__(request)
        except StarletteHTTPException as original_auth_exc:
            raise HTTPException(
                status_code=original_auth_exc.status_code,
                detail=strings.AUTHENTICATION_REQUIRED,
            ) from original_auth_exc


def get_current_user_authorizer(*, required: bool = True) -> Callable:  # type: ignore
    """Return the dependency that resolves the current user.

    With ``required=True`` the strict dependency is returned; otherwise the
    optional variant, which yields ``None`` when no token is supplied.
    """
    return _get_current_user if required else _get_current_user_optional


def _get_authorization_header_retriever(
    *,
    required: bool = True,
) -> Callable:  # type: ignore
    """Return the strict or the optional header-parsing dependency."""
    return _get_authorization_header if required else _get_authorization_header_optional


def _get_authorization_header(
    api_key: str = Security(RWAPIKeyHeader(name=HEADER_KEY)),
    settings: AppSettings = Depends(get_app_settings),
) -> str:
    """Extract the token from a ``"<prefix> <token>"`` header value.

    Raises:
        HTTPException: 403 with ``strings.WRONG_TOKEN_PREFIX`` when the value
            does not split into exactly two space-separated parts, or when
            the prefix differs from ``settings.jwt_token_prefix``.
    """
    try:
        token_prefix, token = api_key.split(" ")
    except ValueError as split_exc:
        # Unpacking fails unless there are exactly two parts.
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=strings.WRONG_TOKEN_PREFIX,
        ) from split_exc

    if token_prefix != settings.jwt_token_prefix:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=strings.WRONG_TOKEN_PREFIX,
        )

    return token


def _get_authorization_header_optional(
    authorization: Optional[str] = Security(
        RWAPIKeyHeader(name=HEADER_KEY, auto_error=False),
    ),
    settings: AppSettings = Depends(get_app_settings),
) -> str:
    """Like ``_get_authorization_header`` but tolerate a missing header.

    Returns an empty string when no header value is present; a present value
    is validated exactly as in the strict variant.
    """
    if authorization:
        return _get_authorization_header(authorization, settings)

    return ""


def _log_fetched_user(user: User) -> None:
    """Emit a best-effort trace of the fetched user; never raises."""
    try:
        # Local import to avoid a global dependency if loguru is not installed.
        from loguru import logger
    except ImportError:
        return

    try:
        logger.info(
            "[Auth] fetched user username={} id/id_={}/{} roles={}",
            getattr(user, "username", None),
            getattr(user, "id", None),
            getattr(user, "id_", None),
            getattr(user, "roles", None),
        )
    except Exception:
        # Diagnostics must not break authentication, but the failure should
        # not vanish silently either: report it through the stdlib logger.
        logging.getLogger(__name__).debug(
            "[Auth] could not log fetched user",
            exc_info=True,
        )


async def _get_current_user(
    users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
    token: str = Depends(_get_authorization_header_retriever()),
    settings: AppSettings = Depends(get_app_settings),
) -> User:
    """Resolve the user identified by the token.

    Raises:
        HTTPException: 403 with ``strings.MALFORMED_PAYLOAD`` when the
            username cannot be read from the token (``ValueError``) or the
            repository has no such user (``EntityDoesNotExist``).
    """
    try:
        username = jwt.get_username_from_token(
            token,
            str(settings.secret_key.get_secret_value()),
        )
    except ValueError as token_exc:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=strings.MALFORMED_PAYLOAD,
        ) from token_exc

    # Keep the try body minimal: only the lookup can raise EntityDoesNotExist.
    try:
        user = await users_repo.get_user_by_username(username=username)
    except EntityDoesNotExist as missing_user_exc:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=strings.MALFORMED_PAYLOAD,
        ) from missing_user_exc

    _log_fetched_user(user)
    return user


async def _get_current_user_optional(
    repo: UsersRepository = Depends(get_repository(UsersRepository)),
    token: str = Depends(_get_authorization_header_retriever(required=False)),
    settings: AppSettings = Depends(get_app_settings),
) -> Optional[User]:
    """Resolve the current user, or return ``None`` when the token is empty."""
    if token:
        return await _get_current_user(repo, token, settings)

    return None