144 lines
4.7 KiB
Python
144 lines
4.7 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
from typing import Iterable, List, Optional
|
|
|
|
from app.db.errors import EntityDoesNotExist
|
|
from app.db.queries.queries import queries
|
|
from app.db.repositories.base import BaseRepository
|
|
from app.models.domain.roles import Role
|
|
|
|
|
|
class RolesRepository(BaseRepository):
    """Data-access layer for roles and user-role assignments.

    Most statements go through the ``queries`` object; ``get_role_by_name``
    and the bulk delete in ``set_roles_for_user`` are issued directly on
    ``self.connection`` with positional (``$1``) parameters.
    """

    @staticmethod
    def _decode_permissions(raw) -> List[str]:
        """Normalize a ``permissions`` value read from the database into a list.

        ``str``/``bytes`` values are treated as JSON text and decoded first;
        text that fails to decode is treated as "no permissions". A JSON string
        literal (e.g. ``'"admin"'``) decodes to a bare ``str``, which becomes a
        one-element list instead of being passed through unwrapped. Falsy
        values (``None``, ``""``, ``[]``, JSON ``null``) yield ``[]``.
        """
        if isinstance(raw, (str, bytes, bytearray)):
            try:
                raw = json.loads(raw)
            except ValueError:  # covers JSONDecodeError and UnicodeDecodeError
                return []
        if not raw:
            return []
        if isinstance(raw, str):
            return [raw]
        return list(raw)

    @staticmethod
    def _permissions_param(permissions: Optional[Iterable[str]]) -> Optional[List[str]]:
        """Materialize caller-supplied permissions for use as a query parameter.

        ``None`` stays ``None`` (meaning "not provided"). A single ``str`` is
        wrapped as one permission rather than iterated character by character;
        an empty string yields ``[]``. Any other iterable is copied into a list.
        """
        if permissions is None:
            return None
        if isinstance(permissions, str):
            return [permissions] if permissions else []
        return list(permissions)

    def _convert_role_row(self, row) -> dict:
        """Turn a role row into keyword arguments for ``Role``.

        ``row`` may be any mapping (a driver record or a plain dict). The
        ``permissions`` field is normalized with ``_decode_permissions`` and is
        set to ``[]`` when the column is missing.

        Raises:
            TypeError: if ``row`` is ``None`` or cannot be converted to a dict.
        """
        data = dict(row)
        data["permissions"] = self._decode_permissions(data.get("permissions"))
        return data

    async def list_roles(self) -> List[Role]:
        """Return every role, in the order produced by ``queries.list_roles``."""
        rows = await queries.list_roles(self.connection)
        return [Role(**self._convert_role_row(row)) for row in rows]

    async def get_role_by_id(self, role_id: int) -> Role:
        """Fetch a single role by primary key.

        Raises:
            EntityDoesNotExist: if no role has ``role_id``.
        """
        row = await queries.get_role_by_id(self.connection, role_id=role_id)
        if not row:
            raise EntityDoesNotExist(f"role {role_id} does not exist")
        return Role(**self._convert_role_row(row))

    async def get_role_by_name(self, *, name: str) -> Optional[Role]:
        """Look up a role by exact name; return ``None`` when nothing matches."""
        row = await self.connection.fetchrow(
            """
            SELECT id, name, description, permissions, created_at, updated_at
            FROM roles
            WHERE name = $1
            """,
            name,
        )
        if not row:
            return None
        return Role(**self._convert_role_row(row))

    async def create_role(
        self,
        *,
        name: str,
        description: Optional[str] = "",
        permissions: Optional[Iterable[str]] = None,
    ) -> Role:
        """Insert a new role and return it as stored.

        ``description=None`` is stored as ``""``; missing permissions are
        stored as an empty list.
        """
        row = await queries.create_role(
            self.connection,
            name=name,
            description=description or "",
            permissions=self._permissions_param(permissions) or [],
        )
        return Role(**self._convert_role_row(row))

    async def update_role(
        self,
        *,
        role_id: int,
        name: Optional[str] = None,
        description: Optional[str] = None,
        permissions: Optional[Iterable[str]] = None,
    ) -> Role:
        """Update a role and return the updated row.

        Fields left as ``None`` are passed to ``queries.update_role`` as
        ``None``; how the query treats them (presumably "keep current value")
        is defined in the SQL, not here.

        Raises:
            EntityDoesNotExist: if the query returns no row for ``role_id``.
        """
        row = await queries.update_role(
            self.connection,
            role_id=role_id,
            name=name,
            description=description,
            permissions=self._permissions_param(permissions),
        )
        if not row:
            raise EntityDoesNotExist(f"role {role_id} does not exist")
        return Role(**self._convert_role_row(row))

    async def ensure_role(
        self,
        *,
        name: str,
        description: Optional[str] = "",
        permissions: Optional[Iterable[str]] = None,
    ) -> Role:
        """Return the role called ``name``, creating it if it does not exist.

        An existing role is returned unchanged; ``description`` and
        ``permissions`` are only used when a new role is created.
        """
        existing = await self.get_role_by_name(name=name)
        if existing is not None:
            return existing
        # NOTE(review): lookup-then-insert is not atomic; a concurrent caller
        # may create the same role in between. Whether that surfaces as an
        # error depends on DB constraints not visible here — confirm.
        return await self.create_role(
            name=name,
            description=description or "",
            permissions=permissions,
        )

    async def delete_role(self, *, role_id: int) -> None:
        """Delete the role with ``role_id`` (no existence check is done here)."""
        await queries.delete_role(self.connection, role_id=role_id)

    async def get_roles_for_user(self, *, user_id: int) -> List[Role]:
        """Return all roles assigned to ``user_id``."""
        rows = await queries.get_roles_for_user(self.connection, user_id=user_id)
        return [Role(**self._convert_role_row(row)) for row in rows]

    async def get_role_names_for_user(self, *, user_id: int) -> List[str]:
        """Return the names of all roles assigned to ``user_id``."""
        return [role.name for role in await self.get_roles_for_user(user_id=user_id)]

    async def assign_role_to_user(self, *, user_id: int, role_id: int) -> None:
        """Assign ``role_id`` to ``user_id`` via ``queries.assign_role_to_user``."""
        await queries.assign_role_to_user(
            self.connection,
            user_id=user_id,
            role_id=role_id,
        )

    async def revoke_role_from_user(self, *, user_id: int, role_id: int) -> None:
        """Remove ``role_id`` from ``user_id`` via ``queries.revoke_role_from_user``."""
        await queries.revoke_role_from_user(
            self.connection,
            user_id=user_id,
            role_id=role_id,
        )

    async def set_roles_for_user(self, *, user_id: int, role_ids: Iterable[int]) -> None:
        """Replace all of ``user_id``'s role assignments with ``role_ids``.

        Duplicate ids are dropped, keeping first-occurrence order. The delete
        and the re-inserts run inside ``self.connection.transaction()`` so
        they commit or roll back together.
        """
        # Materialize (and de-duplicate) before opening the transaction so a
        # lazy iterable is consumed exactly once.
        unique_role_ids = list(dict.fromkeys(role_ids))
        async with self.connection.transaction():
            await self.connection.execute(
                "DELETE FROM user_roles WHERE user_id = $1",
                user_id,
            )
            for role_id in unique_role_ids:
                await queries.assign_role_to_user(
                    self.connection,
                    user_id=user_id,
                    role_id=role_id,
                )

    async def user_has_role(self, *, user_id: int, role_name: str) -> bool:
        """Return ``True`` if the query yields a row whose ``has_role`` is truthy."""
        row = await queries.user_has_role(
            self.connection,
            user_id=user_id,
            role_name=role_name,
        )
        return bool(row and row.get("has_role"))
|