from __future__ import annotations import json from typing import Iterable, List, Optional from app.db.errors import EntityDoesNotExist from app.db.queries.queries import queries from app.db.repositories.base import BaseRepository from app.models.domain.roles import Role class RolesRepository(BaseRepository): def _convert_role_row(self, row) -> dict: permissions = row.get("permissions") if row else [] if isinstance(permissions, str): try: permissions = json.loads(permissions) except ValueError: permissions = [] permissions = permissions or [] return { **row, "permissions": permissions, } async def list_roles(self) -> List[Role]: rows = await queries.list_roles(self.connection) return [Role(**self._convert_role_row(row)) for row in rows] async def get_role_by_id(self, role_id: int) -> Role: row = await queries.get_role_by_id(self.connection, role_id=role_id) if not row: raise EntityDoesNotExist(f"role {role_id} does not exist") return Role(**self._convert_role_row(row)) async def get_role_by_name(self, *, name: str) -> Optional[Role]: row = await self.connection.fetchrow( """ SELECT id, name, description, permissions, created_at, updated_at FROM roles WHERE name = $1 """, name, ) if not row: return None return Role(**self._convert_role_row(dict(row))) async def create_role( self, *, name: str, description: Optional[str] = "", permissions: Optional[Iterable[str]] = None, ) -> Role: row = await queries.create_role( self.connection, name=name, description=description or "", permissions=list(permissions or []), ) return Role(**self._convert_role_row(row)) async def update_role( self, *, role_id: int, name: Optional[str] = None, description: Optional[str] = None, permissions: Optional[Iterable[str]] = None, ) -> Role: row = await queries.update_role( self.connection, role_id=role_id, name=name, description=description, permissions=list(permissions) if permissions is not None else None, ) if not row: raise EntityDoesNotExist(f"role {role_id} does not exist") return Role(**self._convert_role_row(row)) async def ensure_role( self, *, name: str, description: Optional[str] = "", permissions: Optional[Iterable[str]] = None, ) -> Role: existing = await self.get_role_by_name(name=name) if existing: return existing return await self.create_role( name=name, description=description or "", permissions=permissions or [], ) async def delete_role(self, *, role_id: int) -> None: await queries.delete_role(self.connection, role_id=role_id) async def get_roles_for_user(self, *, user_id: int) -> List[Role]: rows = await queries.get_roles_for_user(self.connection, user_id=user_id) return [Role(**self._convert_role_row(row)) for row in rows] async def get_role_names_for_user(self, *, user_id: int) -> List[str]: return [role.name for role in await self.get_roles_for_user(user_id=user_id)] async def assign_role_to_user(self, *, user_id: int, role_id: int) -> None: await queries.assign_role_to_user( self.connection, user_id=user_id, role_id=role_id, ) async def revoke_role_from_user(self, *, user_id: int, role_id: int) -> None: await queries.revoke_role_from_user( self.connection, user_id=user_id, role_id=role_id, ) async def set_roles_for_user(self, *, user_id: int, role_ids: Iterable[int]) -> None: role_ids = list(dict.fromkeys(role_ids)) async with self.connection.transaction(): await self.connection.execute( "DELETE FROM user_roles WHERE user_id = $1", user_id, ) for role_id in role_ids: await queries.assign_role_to_user( self.connection, user_id=user_id, role_id=role_id, ) async def user_has_role(self, *, user_id: int, role_name: str) -> bool: row = await queries.user_has_role( self.connection, user_id=user_id, role_name=role_name, ) return bool(row and row.get("has_role"))