"""WebSocket router: live vote updates with authentication and heartbeat.""" from __future__ import annotations import asyncio import json import logging import uuid from datetime import datetime, timezone from typing import Any from fastapi import APIRouter, WebSocket, WebSocketDisconnect from sqlalchemy import select from app.database import async_session from app.models.user import Session as UserSession router = APIRouter() logger = logging.getLogger(__name__) # Heartbeat interval in seconds _HEARTBEAT_INTERVAL = 30 # Valid notification event types EVENT_TYPES = ( "vote_submitted", "vote_update", "session_closed", "tally_update", "decision_advanced", "mandate_updated", "document_changed", "sanctuary_archived", ) # ── Connection manager ────────────────────────────────────────────────────── class ConnectionManager: """Manages active WebSocket connections grouped by vote session ID.""" def __init__(self) -> None: # session_id -> list of connected websockets self._connections: dict[uuid.UUID, list[WebSocket]] = {} # websocket -> authenticated identity_id (or None for anonymous) self._authenticated: dict[WebSocket, uuid.UUID | None] = {} async def connect(self, websocket: WebSocket, session_id: uuid.UUID) -> None: """Accept a WebSocket connection and register it for a vote session.""" await websocket.accept() if session_id not in self._connections: self._connections[session_id] = [] self._connections[session_id].append(websocket) def disconnect(self, websocket: WebSocket, session_id: uuid.UUID) -> None: """Remove a WebSocket connection from the session group.""" if session_id in self._connections: self._connections[session_id] = [ ws for ws in self._connections[session_id] if ws is not websocket ] if not self._connections[session_id]: del self._connections[session_id] self._authenticated.pop(websocket, None) async def broadcast(self, session_id: uuid.UUID, data: dict[str, Any]) -> None: """Broadcast a message to all connections watching a given vote session.""" if session_id not in self._connections: return message = json.dumps(data, default=str) dead: list[WebSocket] = [] for ws in self._connections[session_id]: try: await ws.send_text(message) except Exception: dead.append(ws) # Clean up dead connections for ws in dead: self.disconnect(ws, session_id) async def broadcast_all(self, data: dict[str, Any]) -> None: """Broadcast a message to all connected WebSockets across all sessions.""" message = json.dumps(data, default=str) for session_id in list(self._connections.keys()): dead: list[WebSocket] = [] for ws in self._connections.get(session_id, []): try: await ws.send_text(message) except Exception: dead.append(ws) for ws in dead: self.disconnect(ws, session_id) def set_authenticated(self, websocket: WebSocket, identity_id: uuid.UUID) -> None: """Mark a WebSocket connection as authenticated.""" self._authenticated[websocket] = identity_id def is_authenticated(self, websocket: WebSocket) -> bool: """Check if a WebSocket connection is authenticated.""" return self._authenticated.get(websocket) is not None manager = ConnectionManager() # ── Authentication helper ────────────────────────────────────────────────── async def _validate_token(token: str) -> uuid.UUID | None: """Validate a bearer token and return the associated identity_id. Uses the same token hashing as auth_service but without FastAPI dependency injection (since WebSocket doesn't use Depends the same way). Parameters ---------- token: The raw bearer token from the query parameter. Returns ------- uuid.UUID | None The identity_id if valid, or None if invalid/expired. """ import hashlib token_hash = hashlib.sha256(token.encode()).hexdigest() try: async with async_session() as db: result = await db.execute( select(UserSession).where( UserSession.token_hash == token_hash, UserSession.expires_at > datetime.now(timezone.utc), ) ) session = result.scalar_one_or_none() if session is not None: return session.identity_id except Exception: logger.warning("Erreur lors de la validation du token WebSocket", exc_info=True) return None # ── Broadcast event helper (importable by other routers) ────────────────── async def broadcast_event( event_type: str, payload: dict[str, Any], session_id: uuid.UUID | None = None, ) -> None: """Broadcast a notification event to connected WebSocket clients. This function is designed to be imported and called from other routers (votes, decisions, mandates, etc.) to push real-time notifications. Parameters ---------- event_type: One of the valid EVENT_TYPES. payload: The event data to send. session_id: If provided, broadcast only to clients watching this specific session. If None, broadcast to all connected clients. """ data = { "event": event_type, "timestamp": datetime.now(timezone.utc).isoformat(), **payload, } if session_id is not None: await manager.broadcast(session_id, data) else: await manager.broadcast_all(data) # ── Heartbeat task ───────────────────────────────────────────────────────── async def _heartbeat(websocket: WebSocket) -> None: """Send periodic ping messages to keep the connection alive. Runs as a background task alongside the main message loop. Sends a JSON ping every _HEARTBEAT_INTERVAL seconds. """ try: while True: await asyncio.sleep(_HEARTBEAT_INTERVAL) try: await websocket.send_text( json.dumps({"event": "ping", "timestamp": datetime.now(timezone.utc).isoformat()}) ) except Exception: break except asyncio.CancelledError: pass # ── WebSocket endpoint ────────────────────────────────────────────────────── @router.websocket("/live") async def live_updates(websocket: WebSocket) -> None: """WebSocket endpoint for live vote session updates. Authentication (optional): Connect with ``?token=`` query parameter to authenticate. If the token is valid, the connection is marked as authenticated. If missing or invalid, the connection is accepted but unauthenticated. The client sends JSON messages: { "action": "subscribe", "session_id": "" } { "action": "unsubscribe", "session_id": "" } The server pushes events: { "event": "vote_update", "session_id": "...", ... } { "event": "session_closed", "session_id": "...", ... } { "event": "vote_submitted", ... } { "event": "decision_advanced", ... } { "event": "mandate_updated", ... } { "event": "document_changed", ... } { "event": "sanctuary_archived", ... } { "event": "ping", "timestamp": "..." } (heartbeat) """ # Extract token from query parameters token = websocket.query_params.get("token") await websocket.accept() subscribed_sessions: set[uuid.UUID] = set() # Authenticate if token provided if token: identity_id = await _validate_token(token) if identity_id is not None: manager.set_authenticated(websocket, identity_id) await websocket.send_text( json.dumps({"event": "authenticated", "identity_id": str(identity_id)}) ) logger.debug("WebSocket authentifie: identity=%s", identity_id) else: await websocket.send_text( json.dumps({"event": "auth_failed", "detail": "Token invalide ou expire"}) ) logger.debug("Echec authentification WebSocket (token invalide)") # Start heartbeat task heartbeat_task = asyncio.create_task(_heartbeat(websocket)) try: while True: raw = await websocket.receive_text() try: data = json.loads(raw) except json.JSONDecodeError: await websocket.send_text(json.dumps({"error": "JSON invalide"})) continue action = data.get("action") session_id_str = data.get("session_id") if not action or not session_id_str: await websocket.send_text( json.dumps({"error": "Champs 'action' et 'session_id' requis"}) ) continue try: session_id = uuid.UUID(session_id_str) except ValueError: await websocket.send_text(json.dumps({"error": "session_id invalide"})) continue if action == "subscribe": if session_id not in subscribed_sessions: # Register this websocket in the manager for this session if session_id not in manager._connections: manager._connections[session_id] = [] manager._connections[session_id].append(websocket) subscribed_sessions.add(session_id) await websocket.send_text( json.dumps({"event": "subscribed", "session_id": str(session_id)}) ) elif action == "unsubscribe": if session_id in subscribed_sessions: manager.disconnect(websocket, session_id) subscribed_sessions.discard(session_id) await websocket.send_text( json.dumps({"event": "unsubscribed", "session_id": str(session_id)}) ) else: await websocket.send_text( json.dumps({"error": f"Action inconnue: {action}"}) ) except WebSocketDisconnect: # Clean up all subscriptions for this client for session_id in subscribed_sessions: manager.disconnect(websocket, session_id) logger.debug("WebSocket deconnecte, %d souscriptions nettoyees", len(subscribed_sessions)) finally: # Cancel the heartbeat task heartbeat_task.cancel() try: await heartbeat_task except asyncio.CancelledError: pass