86 lines
3.1 KiB
Python
86 lines
3.1 KiB
Python
import asyncio
|
|
from collections import defaultdict
|
|
from typing import Dict, Optional, Set
|
|
|
|
from fastapi import WebSocket
|
|
from fastapi.encoders import jsonable_encoder
|
|
|
|
from backend.api.deps import AuthUser
|
|
from backend.modules.instances.models import Instance
|
|
from backend.modules.instances.schemas import InstanceOut
|
|
from backend.modules.users.models import RoleName
|
|
|
|
|
|
class InstanceEventManager:
    """Tracks websocket connections and dispatches instance change events.

    Sockets are indexed by customer key. Sockets of admin users are also kept
    in a separate set so they are included in every broadcast that is sent.
    """

    def __init__(self) -> None:
        # customer key -> sockets registered for that customer (0 = user had no customer_id)
        self._connections: Dict[int, Set[WebSocket]] = defaultdict(set)
        # sockets opened by admin users; added to the targets of every broadcast
        self._admins: Set[WebSocket] = set()
        # reverse index so disconnect() can locate a socket's customer bucket
        self._ws_to_customer: Dict[WebSocket, int] = {}
        # serializes mutation and snapshotting of the three registries above
        self._lock = asyncio.Lock()

    async def connect(self, websocket: WebSocket, auth_user: AuthUser) -> None:
        """Accept *websocket* and register it to receive events for *auth_user*.

        Users whose ``customer_id`` is falsy are filed under key ``0``. Sockets
        whose user ``role_name`` equals ``RoleName.ADMIN.value`` are also
        registered as admin sockets.
        """
        await websocket.accept()
        customer_key = auth_user.customer_id or 0
        async with self._lock:
            self._connections[customer_key].add(websocket)
            self._ws_to_customer[websocket] = customer_key
            if auth_user.role_name == RoleName.ADMIN.value:
                self._admins.add(websocket)

    async def disconnect(self, websocket: WebSocket) -> None:
        """Unregister *websocket*; unknown or already-removed sockets are ignored."""
        async with self._lock:
            customer_key = self._ws_to_customer.pop(websocket, None)
            # Membership test (not indexing) so the defaultdict does not
            # silently recreate an empty bucket.
            if customer_key is not None and customer_key in self._connections:
                bucket = self._connections[customer_key]
                bucket.discard(websocket)
                if not bucket:
                    # Drop empty buckets so customer keys with no sockets do not linger.
                    self._connections.pop(customer_key, None)
            self._admins.discard(websocket)

    async def broadcast(self, payload: dict, customer_id: Optional[int]) -> None:
        """Send *payload* as JSON to all admin sockets and to *customer_id*'s sockets.

        The payload is passed through ``jsonable_encoder`` once. It is then sent
        to every target concurrently, so a slow or stalled client does not hold
        up delivery to the others. Any socket whose send raises an ``Exception``
        is unregistered via ``disconnect``. Nothing is sent when *customer_id*
        is ``None``.
        """
        # NOTE(review): this early return also keeps admins from receiving
        # events for instances with no customer; confirm that is intended.
        if customer_id is None:
            return
        payload_jsonable = jsonable_encoder(payload)
        async with self._lock:
            # Snapshot under the lock. The network sends below run without it,
            # so connect/disconnect are never blocked on I/O (and disconnect,
            # which re-acquires the non-reentrant lock, cannot deadlock).
            targets = list(self._admins | self._connections.get(customer_id, set()))
        if not targets:
            return
        # gather preserves argument order, so results line up with `targets`.
        results = await asyncio.gather(
            *(ws.send_json(payload_jsonable) for ws in targets),
            return_exceptions=True,
        )
        stale = [
            ws for ws, result in zip(targets, results) if isinstance(result, Exception)
        ]
        if not stale:
            return
        import logging  # only needed on the failure path

        logging.getLogger(__name__).debug(
            "Dropping %d websocket(s) after a failed send", len(stale)
        )
        for ws in stale:
            await self.disconnect(ws)
|
|
|
|
|
|
# Module-level manager instance; the broadcast helpers in this module send through it.
instance_event_manager: InstanceEventManager = InstanceEventManager()
|
|
|
|
|
|
def serialize_instance(instance: Instance) -> dict:
    """Return *instance* as a plain dict shaped by the ``InstanceOut`` schema.

    The object is validated through ``InstanceOut.model_validate`` and the
    resulting schema object is dumped with ``model_dump``.
    """
    validated = InstanceOut.model_validate(instance)
    return validated.model_dump()
|
|
|
|
|
|
def build_removed_payload(instance: Instance) -> dict:
    """Return a plain-dict snapshot of selected attributes of *instance*.

    Each key is the name of the attribute it was read from, in a fixed order.
    NOTE(review): presumably captured before the row is deleted and then passed
    to ``broadcast_instance_removed``; confirm against callers.
    """
    field_names = (
        "id",
        "instance_id",
        "account_id",
        "region",
        "customer_id",
        "credential_id",
        "status",
    )
    return {name: getattr(instance, name) for name in field_names}
|
|
|
|
|
|
async def broadcast_instance_update(instance: Instance) -> None:
    """Broadcast an ``instance_update`` event carrying the serialized *instance*.

    The event is routed by ``instance.customer_id`` through the module-level
    ``instance_event_manager``.
    """
    event = {"type": "instance_update", "instance": serialize_instance(instance)}
    await instance_event_manager.broadcast(event, instance.customer_id)
|
|
|
|
|
|
async def broadcast_instance_removed(payload: dict, customer_id: Optional[int]) -> None:
    """Broadcast an ``instance_removed`` event wrapping a prebuilt *payload*.

    The event is routed by *customer_id* through the module-level
    ``instance_event_manager``.
    """
    event = {"type": "instance_removed", "instance": payload}
    await instance_event_manager.broadcast(event, customer_id)
|