AI-News/backend/app/api/dependencies/authentication.py
2025-12-04 10:04:21 +08:00

124 lines
3.9 KiB
Python

# noqa:WPS201
from typing import Callable, Optional
from fastapi import Depends, HTTPException, Security
from fastapi.security import APIKeyHeader
from starlette import requests, status
from starlette.exceptions import HTTPException as StarletteHTTPException
from app.api.dependencies.database import get_repository
from app.core.config import get_app_settings
from app.core.settings.app import AppSettings
from app.db.errors import EntityDoesNotExist
from app.db.repositories.users import UsersRepository
from app.models.domain.users import User
from app.resources import strings
from app.services import jwt
# Name of the HTTP request header that carries "<prefix> <token>".
HEADER_KEY: str = "Authorization"
class RWAPIKeyHeader(APIKeyHeader):
    """API-key header scheme that reports failures with the project's message.

    Any ``StarletteHTTPException`` raised by the parent ``APIKeyHeader`` is
    re-raised as a FastAPI ``HTTPException``. The re-raised exception keeps
    the original status code (and headers, if present) but uses
    ``strings.AUTHENTICATION_REQUIRED`` as the detail.
    """

    async def __call__(  # noqa: WPS610
        self,
        request: requests.Request,
    ) -> Optional[str]:
        """Return the header value extracted by the parent scheme.

        Args:
            request: the incoming request to read the header from.

        Returns:
            Whatever the parent ``APIKeyHeader.__call__`` returns.

        Raises:
            HTTPException: when the parent raises a Starlette HTTP exception.
                It carries the same status code and headers, with
                ``strings.AUTHENTICATION_REQUIRED`` as the detail.
        """
        try:
            return await super().__call__(request)
        except StarletteHTTPException as original_auth_exc:
            # Forward any headers the parent attached instead of dropping
            # them. ``getattr`` guards Starlette versions without the
            # attribute. Chain the cause so tracebacks show the origin.
            raise HTTPException(
                status_code=original_auth_exc.status_code,
                detail=strings.AUTHENTICATION_REQUIRED,
                headers=getattr(original_auth_exc, "headers", None),
            ) from original_auth_exc
def get_current_user_authorizer(*, required: bool = True) -> Callable:  # type: ignore
    """Choose the dependency used to resolve the current user.

    Args:
        required: if True, return ``_get_current_user``, which needs a valid
            token. Otherwise return ``_get_current_user_optional``, which
            yields ``None`` when no token is supplied.

    Returns:
        The selected dependency callable (not called).
    """
    if not required:
        return _get_current_user_optional
    return _get_current_user
def _get_authorization_header_retriever(
    *,
    required: bool = True,
) -> Callable:  # type: ignore
    """Choose the dependency that extracts the token from the auth header.

    Args:
        required: if True, return ``_get_authorization_header``. Otherwise
            return ``_get_authorization_header_optional``, which returns
            ``""`` when the header is absent or empty.

    Returns:
        The selected dependency callable (not called).
    """
    if not required:
        return _get_authorization_header_optional
    return _get_authorization_header
def _get_authorization_header(
    api_key: str = Security(RWAPIKeyHeader(name=HEADER_KEY)),
    settings: AppSettings = Depends(get_app_settings),
) -> str:
    """Extract the token from a header value of the form ``"<prefix> <token>"``.

    The value is split on single spaces and must yield exactly two parts.
    The first part must equal ``settings.jwt_token_prefix``.

    Args:
        api_key: raw header value.
        settings: application settings providing the expected prefix.

    Returns:
        The token part of the header value.

    Raises:
        HTTPException: 403 with ``strings.WRONG_TOKEN_PREFIX`` if the value
            does not split into exactly two parts or the prefix does not match.
    """
    pieces = api_key.split(" ")
    # Short-circuit: the prefix is only inspected when there are two pieces.
    if len(pieces) != 2 or pieces[0] != settings.jwt_token_prefix:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=strings.WRONG_TOKEN_PREFIX,
        )
    return pieces[1]
def _get_authorization_header_optional(
    authorization: Optional[str] = Security(
        RWAPIKeyHeader(name=HEADER_KEY, auto_error=False),
    ),
    settings: AppSettings = Depends(get_app_settings),
) -> str:
    """Extract the token if an auth header was sent, else return ``""``.

    Args:
        authorization: raw header value, or a falsy value when absent
            (the scheme is built with ``auto_error=False``).
        settings: application settings passed through to the validator.

    Returns:
        The token, or an empty string when no header value is present.

    Raises:
        HTTPException: propagated from ``_get_authorization_header`` when a
            header is present but malformed.
    """
    if not authorization:
        return ""
    return _get_authorization_header(authorization, settings)
def _log_fetched_user(user: User) -> None:
    """Log the user resolved from a token at debug level, if loguru exists.

    loguru is imported lazily so it stays an optional dependency. Only a
    missing package is tolerated; other errors are no longer silently
    swallowed.

    Args:
        user: the user returned by the repository lookup.
    """
    try:
        from loguru import logger  # noqa: WPS433 (optional dependency)
    except ImportError:
        return
    # DEBUG rather than INFO: this fires on every authenticated request and
    # includes identity and role data that should not flood production logs.
    logger.debug(
        "[Auth] fetched user username={} id/id_={}/{} roles={}",
        getattr(user, "username", None),
        getattr(user, "id", None),
        getattr(user, "id_", None),
        getattr(user, "roles", None),
    )


async def _get_current_user(
    users_repo: UsersRepository = Depends(get_repository(UsersRepository)),
    token: str = Depends(_get_authorization_header_retriever()),
    settings: AppSettings = Depends(get_app_settings),
) -> User:
    """Resolve the authenticated user from a token.

    The token is decoded with ``settings.secret_key`` via
    ``jwt.get_username_from_token``. The resulting username is then looked
    up in the users repository.

    Args:
        users_repo: repository used to fetch the user.
        token: the token extracted from the auth header.
        settings: application settings providing the secret key.

    Returns:
        The user whose username is encoded in the token.

    Raises:
        HTTPException: 403 with ``strings.MALFORMED_PAYLOAD`` if decoding
            raises ``ValueError`` or no user with that username exists.
    """
    try:
        username = jwt.get_username_from_token(
            token,
            str(settings.secret_key.get_secret_value()),
        )
    except ValueError as decode_exc:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=strings.MALFORMED_PAYLOAD,
        ) from decode_exc

    # Keep the try narrow: only the lookup is expected to raise
    # EntityDoesNotExist; logging and the return stay outside it.
    try:
        user = await users_repo.get_user_by_username(username=username)
    except EntityDoesNotExist as missing_exc:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=strings.MALFORMED_PAYLOAD,
        ) from missing_exc

    _log_fetched_user(user)
    return user
async def _get_current_user_optional(
    repo: UsersRepository = Depends(get_repository(UsersRepository)),
    token: str = Depends(_get_authorization_header_retriever(required=False)),
    settings: AppSettings = Depends(get_app_settings),
) -> Optional[User]:
    """Resolve the current user when a token was supplied, else ``None``.

    Args:
        repo: repository used to fetch the user.
        token: extracted token, or ``""`` when no auth header was sent.
        settings: application settings passed through to the resolver.

    Returns:
        The resolved user, or ``None`` if ``token`` is empty.

    Raises:
        HTTPException: propagated from ``_get_current_user`` when a token is
            present but invalid.
    """
    if not token:
        return None
    return await _get_current_user(repo, token, settings)