import asyncio
import logging
from collections import defaultdict
from typing import Dict, Optional, Set

from fastapi import WebSocket
from fastapi.encoders import jsonable_encoder

from backend.api.deps import AuthUser
from backend.modules.instances.models import Instance
from backend.modules.instances.schemas import InstanceOut
from backend.modules.users.models import RoleName

logger = logging.getLogger(__name__)


class InstanceEventManager:
    """Tracks websocket connections and dispatches instance change events.

    Connections are bucketed by the authenticated user's ``customer_id``
    (``0`` when the user has none). Admin connections are also kept in a
    separate set so that they receive every broadcast.
    """

    def __init__(self) -> None:
        # customer key (0 for users without a customer) -> open sockets.
        self._connections: Dict[int, Set[WebSocket]] = defaultdict(set)
        # Sockets owned by admins; included in every broadcast.
        self._admins: Set[WebSocket] = set()
        # Reverse index so disconnect() can find a socket's bucket.
        self._ws_to_customer: Dict[WebSocket, int] = {}
        # Guards the three registries above. Only held for short, I/O-free
        # critical sections so a slow client cannot stall connect/disconnect.
        self._lock = asyncio.Lock()
        # Serializes broadcasts: events reach each socket in broadcast order
        # and no socket is written to by two broadcasts at the same time.
        self._send_lock = asyncio.Lock()

    async def connect(self, websocket: WebSocket, auth_user: AuthUser) -> None:
        """Accept *websocket* and register it for *auth_user*'s events.

        The socket is filed under ``auth_user.customer_id`` (``0`` when that
        value is falsy) and, for admin users, also added to the admin set.
        """
        await websocket.accept()
        customer_key = auth_user.customer_id or 0
        async with self._lock:
            self._connections[customer_key].add(websocket)
            self._ws_to_customer[websocket] = customer_key
            if auth_user.role_name == RoleName.ADMIN.value:
                self._admins.add(websocket)

    async def disconnect(self, websocket: WebSocket) -> None:
        """Remove *websocket* from every registry; calling it twice is harmless.

        This only unregisters the socket; it does not close it.
        """
        async with self._lock:
            customer_key = self._ws_to_customer.pop(websocket, None)
            if customer_key is not None:
                bucket = self._connections.get(customer_key)
                if bucket is not None:
                    bucket.discard(websocket)
                    if not bucket:
                        # Drop empty buckets so the mapping does not keep growing.
                        del self._connections[customer_key]
            self._admins.discard(websocket)

    async def broadcast(self, payload: dict, customer_id: Optional[int]) -> None:
        """Send *payload* to all admin sockets and all sockets of *customer_id*.

        *payload* is passed through ``jsonable_encoder`` once before sending.
        Sockets whose send raises are unregistered via ``disconnect``.
        When *customer_id* is ``None`` nothing is sent, not even to admins.
        """
        # NOTE(review): events for instances without a customer never reach
        # admins; confirm that this is intended.
        if customer_id is None:
            return
        payload_jsonable = jsonable_encoder(payload)
        async with self._send_lock:
            # Snapshot recipients under the registry lock and release it before
            # any network I/O. Awaiting sends while holding it would let one
            # slow client block every connect()/disconnect().
            async with self._lock:
                targets = self._admins | self._connections.get(customer_id, set())
            if not targets:
                return
            stale = await self._send_to_all(targets, payload_jsonable)
            # disconnect() takes only the registry lock, which is free here.
            for ws in stale:
                await self.disconnect(ws)

    @staticmethod
    async def _send_to_all(targets: Set[WebSocket], message: object) -> list[WebSocket]:
        """Send *message* to each socket in *targets*; return those that failed."""
        failed: list[WebSocket] = []
        for ws in targets:
            try:
                await ws.send_json(message)
            except Exception:
                # Best effort: a failing socket is dropped rather than aborting
                # the broadcast for everyone else.
                logger.debug("Dropping websocket after failed send", exc_info=True)
                failed.append(ws)
        return failed


instance_event_manager = InstanceEventManager()


def serialize_instance(instance: Instance) -> dict:
    """Return *instance* as a plain dict produced by the ``InstanceOut`` schema."""
    return InstanceOut.model_validate(instance).model_dump()


def build_removed_payload(instance: Instance) -> dict:
    """Return a minimal dict of *instance*'s identifying and ownership fields.

    The fields are read directly from the model attributes rather than going
    through ``InstanceOut``.
    """
    return {
        "id": instance.id,
        "instance_id": instance.instance_id,
        "account_id": instance.account_id,
        "region": instance.region,
        "customer_id": instance.customer_id,
        "credential_id": instance.credential_id,
        "status": instance.status,
    }


async def broadcast_instance_update(instance: Instance) -> None:
    """Broadcast an ``instance_update`` event for *instance* to its customer and admins."""
    await instance_event_manager.broadcast(
        {"type": "instance_update", "instance": serialize_instance(instance)},
        instance.customer_id,
    )


async def broadcast_instance_removed(payload: dict, customer_id: Optional[int]) -> None:
    """Broadcast an ``instance_removed`` event carrying *payload* for *customer_id*."""
    await instance_event_manager.broadcast(
        {"type": "instance_removed", "instance": payload},
        customer_id,
    )