124 lines
3.9 KiB
Python
124 lines
3.9 KiB
Python
# noqa:WPS201
|
|
from typing import Callable, Optional
|
|
|
|
from fastapi import Depends, HTTPException, Security
|
|
from fastapi.security import APIKeyHeader
|
|
from starlette import requests, status
|
|
from starlette.exceptions import HTTPException as StarletteHTTPException
|
|
|
|
from app.api.dependencies.database import get_repository
|
|
from app.core.config import get_app_settings
|
|
from app.core.settings.app import AppSettings
|
|
from app.db.errors import EntityDoesNotExist
|
|
from app.db.repositories.users import UsersRepository
|
|
from app.models.domain.users import User
|
|
from app.resources import strings
|
|
from app.services import jwt
|
|
|
|
# Name of the HTTP request header that carries "<prefix> <token>" credentials.
HEADER_KEY: str = "Authorization"
|
|
|
|
|
|
class RWAPIKeyHeader(APIKeyHeader):
    """API-key header scheme that replaces the parent's error detail.

    Delegates to FastAPI's ``APIKeyHeader``. When the parent rejects the request
    by raising a Starlette ``HTTPException``, this re-raises a FastAPI
    ``HTTPException`` with the same status code and any headers the parent
    attached, but with ``strings.AUTHENTICATION_REQUIRED`` as the detail.
    """

    async def __call__(  # noqa: WPS610
        self,
        request: requests.Request,
    ) -> Optional[str]:
        """Return the header value as resolved by ``APIKeyHeader``.

        Raises:
            HTTPException: when the parent raises, with the parent's status code
                and headers and ``strings.AUTHENTICATION_REQUIRED`` as detail.
        """
        try:
            return await super().__call__(request)
        except StarletteHTTPException as original_auth_exc:
            raise HTTPException(
                status_code=original_auth_exc.status_code,
                detail=strings.AUTHENTICATION_REQUIRED,
                # Forward whatever headers the parent attached (None if it had
                # none) so they are not lost when the detail is replaced.
                headers=getattr(original_auth_exc, "headers", None),
            ) from original_auth_exc
|
|
|
|
|
|
def get_current_user_authorizer(*, required: bool = True) -> Callable:  # type: ignore
    """Pick the FastAPI dependency that resolves the current user.

    Args:
        required: when true, return the dependency that insists on a valid
            token; when false, return the one that yields ``None`` if no
            token was supplied.

    Returns:
        ``_get_current_user`` or ``_get_current_user_optional``.
    """
    if not required:
        return _get_current_user_optional
    return _get_current_user
|
|
|
|
|
|
def _get_authorization_header_retriever(
    *,
    required: bool = True,
) -> Callable:  # type: ignore
    """Pick the dependency that extracts the token from the auth header.

    Args:
        required: when true, pick the strict extractor, which raises on a
            malformed or wrong-prefix header. When false, pick the lenient
            one, which returns ``""`` if the header is absent.

    Returns:
        ``_get_authorization_header`` or ``_get_authorization_header_optional``.
    """
    if not required:
        return _get_authorization_header_optional
    return _get_authorization_header
|
|
|
|
|
|
def _get_authorization_header(
    api_key: str = Security(RWAPIKeyHeader(name=HEADER_KEY)),
    settings: AppSettings = Depends(get_app_settings),
) -> str:
    """Return the bare token from a ``"<prefix> <token>"`` header value.

    The value must split on single spaces into exactly two parts, and the
    first part must equal ``settings.jwt_token_prefix``.

    Raises:
        HTTPException: 403 with ``strings.WRONG_TOKEN_PREFIX`` when the value
            does not have that two-part shape or the prefix does not match.
    """
    segments = api_key.split(" ")
    # Short-circuit: the prefix is only compared once the shape is known good.
    well_formed = len(segments) == 2
    if not well_formed or segments[0] != settings.jwt_token_prefix:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=strings.WRONG_TOKEN_PREFIX,
        )
    return segments[1]
|
|
|
|
|
|
def _get_authorization_header_optional(
    authorization: Optional[str] = Security(
        RWAPIKeyHeader(name=HEADER_KEY, auto_error=False),
    ),
    settings: AppSettings = Depends(get_app_settings),
) -> str:
    """Return the token if an auth header was sent, otherwise ``""``.

    The header scheme is built with ``auto_error=False``, so a missing header
    is expected to arrive as a falsy value (presumably ``None``) instead of an
    error. A non-empty value is validated exactly like the mandatory variant
    and can raise the same ``HTTPException``.
    """
    if not authorization:
        return ""
    return _get_authorization_header(authorization, settings)
|
|
|
|
|
|
def _log_fetched_user(user: User) -> None:
    """Best-effort debug log of an authenticated user's identity fields.

    ``loguru`` is imported lazily so this module does not hard-depend on it.
    If the import fails, the log line is skipped rather than failing the request.
    Logged at debug level so user identifiers and roles are not written to
    INFO logs on every authenticated request.
    """
    try:
        from loguru import logger  # noqa: WPS433
    except ImportError:
        return
    # getattr with a default keeps the call from raising AttributeError if a
    # field is absent on the User model.
    logger.debug(
        "[Auth] fetched user username={} id/id_={}/{} roles={}",
        getattr(user, "username", None),
        getattr(user, "id", None),
        getattr(user, "id_", None),
        getattr(user, "roles", None),
    )


async def _get_current_user(
    users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
    token: str = Depends(_get_authorization_header_retriever()),
    settings: AppSettings = Depends(get_app_settings),
) -> User:
    """Resolve the authenticated ``User`` from a bearer-style token.

    Decodes ``token`` with the application's secret key to get a username,
    then loads that user from ``users_repo``.

    Raises:
        HTTPException: 403 with ``strings.MALFORMED_PAYLOAD`` when
            ``jwt.get_username_from_token`` raises ``ValueError``, or when the
            repository raises ``EntityDoesNotExist`` for that username.
    """
    try:
        username = jwt.get_username_from_token(
            token,
            str(settings.secret_key.get_secret_value()),
        )
    except ValueError as decode_error:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=strings.MALFORMED_PAYLOAD,
        ) from decode_error

    # Keep the try body to the one call that can raise EntityDoesNotExist.
    try:
        user = await users_repo.get_user_by_username(username=username)
    except EntityDoesNotExist as missing_user:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=strings.MALFORMED_PAYLOAD,
        ) from missing_user

    _log_fetched_user(user)
    return user
|
|
|
|
|
|
async def _get_current_user_optional(
    repo: UsersRepository = Depends(get_repository(UsersRepository)),
    token: str = Depends(_get_authorization_header_retriever(required=False)),
    settings: AppSettings = Depends(get_app_settings),
) -> Optional[User]:
    """Return the authenticated ``User`` when a token is present, else ``None``.

    An empty token is what the optional header dependency yields when no
    header was sent, and it means anonymous access. A non-empty token goes
    through ``_get_current_user`` and can raise the same ``HTTPException``.
    """
    if not token:
        return None
    return await _get_current_user(repo, token, settings)
|