2025-12-04 10:04:21 +08:00

144 lines
4.7 KiB
Python

from __future__ import annotations
import json
from typing import Iterable, List, Optional
from app.db.errors import EntityDoesNotExist
from app.db.queries.queries import queries
from app.db.repositories.base import BaseRepository
from app.models.domain.roles import Role
class RolesRepository(BaseRepository):
    """Repository for the ``roles`` table and user/role assignments.

    Most statements go through the shared ``queries`` object; the lookup by
    name and the bulk clearing of a user's roles are issued directly on
    ``self.connection`` (presumably an asyncpg connection — confirm).
    """

    def _convert_role_row(self, row) -> dict:
        """Turn a database row into keyword arguments for ``Role``.

        The ``permissions`` value is normalised: a string is parsed as JSON
        (unparseable JSON becomes ``[]``), and any falsy result becomes
        ``[]``. Every other column is copied through unchanged.
        """
        decoded = [] if not row else row.get("permissions")
        if isinstance(decoded, str):
            try:
                decoded = json.loads(decoded)
            except ValueError:
                # Bad JSON is treated as "no permissions" instead of failing.
                decoded = []
        fields = {**row}
        fields["permissions"] = decoded or []
        return fields

    def _row_to_role(self, row) -> Role:
        """Build a ``Role`` from a raw row via ``_convert_role_row``."""
        return Role(**self._convert_role_row(row))

    async def list_roles(self) -> List[Role]:
        """Return all roles produced by the ``list_roles`` query."""
        records = await queries.list_roles(self.connection)
        return [self._row_to_role(record) for record in records]

    async def get_role_by_id(self, role_id: int) -> Role:
        """Fetch one role by primary key.

        Raises:
            EntityDoesNotExist: if the query returns no row.
        """
        record = await queries.get_role_by_id(self.connection, role_id=role_id)
        if record:
            return self._row_to_role(record)
        raise EntityDoesNotExist(f"role {role_id} does not exist")

    async def get_role_by_name(self, *, name: str) -> Optional[Role]:
        """Fetch one role by its ``name`` column, or ``None`` if absent."""
        sql = (
            "\n"
            "SELECT id, name, description, permissions, created_at, updated_at\n"
            "FROM roles\n"
            "WHERE name = $1\n"
        )
        record = await self.connection.fetchrow(sql, name)
        return self._row_to_role(dict(record)) if record else None

    async def create_role(
        self,
        *,
        name: str,
        description: Optional[str] = "",
        permissions: Optional[Iterable[str]] = None,
    ) -> Role:
        """Insert a role and return it.

        A falsy ``description`` is stored as ``""``; falsy ``permissions``
        are stored as an empty list, otherwise the iterable is materialised.
        """
        clean_description = description if description else ""
        permission_list = list(permissions) if permissions else []
        record = await queries.create_role(
            self.connection,
            name=name,
            description=clean_description,
            permissions=permission_list,
        )
        return self._row_to_role(record)

    async def update_role(
        self,
        *,
        role_id: int,
        name: Optional[str] = None,
        description: Optional[str] = None,
        permissions: Optional[Iterable[str]] = None,
    ) -> Role:
        """Update a role and return its new state.

        ``name``, ``description`` and ``permissions`` are forwarded as given
        (``None`` included); a non-``None`` ``permissions`` iterable is
        materialised into a list first.

        Raises:
            EntityDoesNotExist: if the update query returns no row.
        """
        permission_list = None if permissions is None else list(permissions)
        record = await queries.update_role(
            self.connection,
            role_id=role_id,
            name=name,
            description=description,
            permissions=permission_list,
        )
        if record:
            return self._row_to_role(record)
        raise EntityDoesNotExist(f"role {role_id} does not exist")

    async def ensure_role(
        self,
        *,
        name: str,
        description: Optional[str] = "",
        permissions: Optional[Iterable[str]] = None,
    ) -> Role:
        """Return the role named ``name``, creating it if it is missing.

        An existing role is returned as-is; ``description`` and
        ``permissions`` are only used when a new role is created.
        """
        # NOTE(review): lookup and insert are separate statements, so a
        # concurrent caller could create the same name in between — confirm
        # whether a unique constraint on roles.name covers this.
        current = await self.get_role_by_name(name=name)
        if current:
            return current
        return await self.create_role(
            name=name,
            description=description or "",
            permissions=permissions or [],
        )

    async def delete_role(self, *, role_id: int) -> None:
        """Run the ``delete_role`` query; no existence check is done here."""
        await queries.delete_role(self.connection, role_id=role_id)

    async def get_roles_for_user(self, *, user_id: int) -> List[Role]:
        """Return the roles the ``get_roles_for_user`` query yields for a user."""
        records = await queries.get_roles_for_user(self.connection, user_id=user_id)
        return [self._row_to_role(record) for record in records]

    async def get_role_names_for_user(self, *, user_id: int) -> List[str]:
        """Return just the ``name`` of each role assigned to the user."""
        roles = await self.get_roles_for_user(user_id=user_id)
        return [role.name for role in roles]

    async def assign_role_to_user(self, *, user_id: int, role_id: int) -> None:
        """Run the ``assign_role_to_user`` query for one user/role pair."""
        await queries.assign_role_to_user(self.connection, user_id=user_id, role_id=role_id)

    async def revoke_role_from_user(self, *, user_id: int, role_id: int) -> None:
        """Run the ``revoke_role_from_user`` query for one user/role pair."""
        await queries.revoke_role_from_user(self.connection, user_id=user_id, role_id=role_id)

    async def set_roles_for_user(self, *, user_id: int, role_ids: Iterable[int]) -> None:
        """Replace a user's role assignments with ``role_ids``.

        Inside one transaction, every ``user_roles`` row for the user is
        deleted, then each distinct id is assigned once, in first-seen order.
        """
        # dict.fromkeys drops duplicates while keeping first-occurrence order.
        unique_ids = list(dict.fromkeys(role_ids))
        async with self.connection.transaction():
            await self.connection.execute("DELETE FROM user_roles WHERE user_id = $1", user_id)
            for rid in unique_ids:
                await queries.assign_role_to_user(self.connection, user_id=user_id, role_id=rid)

    async def user_has_role(self, *, user_id: int, role_name: str) -> bool:
        """Return whether the ``user_has_role`` query reports a truthy ``has_role``."""
        record = await queries.user_has_role(
            self.connection,
            user_id=user_id,
            role_name=role_name,
        )
        if not record:
            return False
        return bool(record.get("has_role"))