"""Votes router: vote sessions, individual votes, result computation, and live updates.""" from __future__ import annotations import math import uuid from datetime import datetime, timedelta, timezone from fastapi import APIRouter, Depends, HTTPException, Query, status from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload from app.database import get_db from app.models.protocol import FormulaConfig, VotingProtocol from app.models.user import DuniterIdentity from app.models.vote import Vote, VoteSession from app.routers.websocket import manager from app.schemas.vote import ( ThresholdDetailOut, VoteCreate, VoteOut, VoteResultOut, VoteSessionCreate, VoteSessionListOut, VoteSessionOut, ) from app.services.auth_service import get_current_identity from app.services.vote_service import ( close_session as svc_close_session, compute_result as svc_compute_result, get_threshold_details as svc_get_threshold_details, ) router = APIRouter() # ── Helpers ───────────────────────────────────────────────────────────────── async def _get_session(db: AsyncSession, session_id: uuid.UUID) -> VoteSession: """Fetch a vote session by ID with votes eagerly loaded, or raise 404.""" result = await db.execute( select(VoteSession) .options(selectinload(VoteSession.votes)) .where(VoteSession.id == session_id) ) session = result.scalar_one_or_none() if session is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Session de vote introuvable") return session async def _get_protocol_with_formula(db: AsyncSession, protocol_id: uuid.UUID) -> VotingProtocol: """Fetch a voting protocol with its formula config, or raise 404.""" result = await db.execute( select(VotingProtocol) .options(selectinload(VotingProtocol.formula_config)) .where(VotingProtocol.id == protocol_id) ) protocol = result.scalar_one_or_none() if protocol is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Protocole de vote introuvable") return protocol async def _check_session_expired(session: VoteSession, db: AsyncSession) -> VoteSession: """Check if a session has passed its ends_at deadline and auto-close it. If the session is still marked 'open' but the deadline has passed, close it and compute the final tally via the vote service. Returns the (possibly updated) session. """ if session.status == "open" and datetime.now(timezone.utc) > session.ends_at: try: result = await svc_close_session(session.id, db) # Reload session to get updated fields db_result = await db.execute( select(VoteSession) .options(selectinload(VoteSession.votes)) .where(VoteSession.id == session.id) ) session = db_result.scalar_one() # Broadcast session closed event await manager.broadcast(session.id, { "type": "session_closed", "session_id": str(session.id), "result": result.get("result"), "votes_for": result.get("votes_for", 0), "votes_against": result.get("votes_against", 0), "votes_total": result.get("votes_total", 0), }) except ValueError: pass # Session already closed by another process return session def _compute_threshold(formula: FormulaConfig, wot_size: int, votes_total: int) -> float: """Compute the WoT-based threshold using the core formula. Result = C + B^W + (M + (1-M) * (1 - (T/W)^G)) * max(0, T-C) Where: - C = constant_base - B = base_exponent - W = wot_size - M = majority_pct / 100 - G = gradient_exponent - T = votes_total (turnout) """ c = formula.constant_base b = formula.base_exponent w = max(wot_size, 1) m = formula.majority_pct / 100.0 g = formula.gradient_exponent t = votes_total # Inertia-based threshold base_power = b ** w if b > 0 else 0.0 turnout_ratio = min(t / w, 1.0) if w > 0 else 0.0 inertia = m + (1 - m) * (1 - turnout_ratio ** g) threshold = c + base_power + inertia * max(0, t - c) return threshold def _compute_result( session: VoteSession, formula: FormulaConfig, ) -> dict: """Compute the vote result based on tallies and formula. Returns a dict with threshold_required, result ("adopted" or "rejected"), and whether Smith/TechComm criteria are met. """ threshold = _compute_threshold(formula, session.wot_size, session.votes_total) # Main criterion: votes_for >= threshold main_pass = session.votes_for >= threshold # Smith criterion (if configured) smith_pass = True smith_threshold = None if formula.smith_exponent is not None and session.smith_size > 0: smith_threshold = math.ceil(session.smith_size ** formula.smith_exponent) smith_pass = session.smith_votes_for >= smith_threshold # TechComm criterion (if configured) techcomm_pass = True techcomm_threshold = None if formula.techcomm_exponent is not None and session.techcomm_size > 0: techcomm_threshold = math.ceil(session.techcomm_size ** formula.techcomm_exponent) techcomm_pass = session.techcomm_votes_for >= techcomm_threshold result = "adopted" if (main_pass and smith_pass and techcomm_pass) else "rejected" return { "threshold_required": threshold, "result": result, "smith_threshold": smith_threshold, "smith_pass": smith_pass, "techcomm_threshold": techcomm_threshold, "techcomm_pass": techcomm_pass, } # ── Routes ────────────────────────────────────────────────────────────────── @router.get("/sessions", response_model=list[VoteSessionListOut]) async def list_vote_sessions( db: AsyncSession = Depends(get_db), session_status: str | None = Query(default=None, alias="status", description="Filtrer par statut (open, closed, tallied)"), decision_id: uuid.UUID | None = Query(default=None, description="Filtrer par decision_id"), skip: int = Query(default=0, ge=0), limit: int = Query(default=50, ge=1, le=200), ) -> list[VoteSessionListOut]: """List all vote sessions with optional filters by status and decision_id.""" stmt = select(VoteSession) if session_status is not None: stmt = stmt.where(VoteSession.status == session_status) if decision_id is not None: stmt = stmt.where(VoteSession.decision_id == decision_id) stmt = stmt.order_by(VoteSession.created_at.desc()).offset(skip).limit(limit) result = await db.execute(stmt) sessions = result.scalars().all() # Auto-close expired sessions before returning checked_sessions = [] for s in sessions: s = await _check_session_expired(s, db) checked_sessions.append(s) return [VoteSessionListOut.model_validate(s) for s in checked_sessions] @router.post("/sessions", response_model=VoteSessionOut, status_code=status.HTTP_201_CREATED) async def create_vote_session( payload: VoteSessionCreate, db: AsyncSession = Depends(get_db), identity: DuniterIdentity = Depends(get_current_identity), ) -> VoteSessionOut: """Create a new vote session. The session duration is derived from the linked protocol's formula config. WoT/Smith/TechComm sizes should be snapshotted from the blockchain at creation time. """ # Validate protocol exists and get formula for duration protocol = await _get_protocol_with_formula(db, payload.voting_protocol_id) formula = protocol.formula_config starts_at = datetime.now(timezone.utc) ends_at = starts_at + timedelta(days=formula.duration_days) session = VoteSession( decision_id=payload.decision_id, item_version_id=payload.item_version_id, voting_protocol_id=payload.voting_protocol_id, starts_at=starts_at, ends_at=ends_at, # TODO: Snapshot actual WoT sizes from blockchain via Duniter RPC wot_size=0, smith_size=0, techcomm_size=0, ) db.add(session) await db.commit() await db.refresh(session) return VoteSessionOut.model_validate(session) @router.get("/sessions/{id}", response_model=VoteSessionOut) async def get_vote_session( id: uuid.UUID, db: AsyncSession = Depends(get_db), ) -> VoteSessionOut: """Get a vote session with current tallies. Automatically closes the session if its deadline has passed. """ session = await _get_session(db, id) session = await _check_session_expired(session, db) return VoteSessionOut.model_validate(session) @router.post("/sessions/{id}/close", response_model=VoteResultOut) async def close_vote_session( id: uuid.UUID, db: AsyncSession = Depends(get_db), identity: DuniterIdentity = Depends(get_current_identity), ) -> VoteResultOut: """Manually close a vote session and compute the final result. Requires authentication. The session must be in 'open' status. """ try: result = await svc_close_session(id, db) except ValueError as exc: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) # Broadcast session closed event via WebSocket await manager.broadcast(id, { "type": "session_closed", "session_id": str(id), "result": result.get("result"), "votes_for": result.get("votes_for", 0), "votes_against": result.get("votes_against", 0), "votes_total": result.get("votes_total", 0), }) return VoteResultOut( session_id=id, result=result.get("result"), threshold_required=float(result.get("threshold", 0)), votes_for=result.get("votes_for", 0), votes_against=result.get("votes_against", 0), votes_total=result.get("votes_total", 0), adopted=result.get("adopted", False), nuanced_breakdown=result.get("nuanced_breakdown"), ) @router.post("/sessions/{id}/vote", response_model=VoteOut, status_code=status.HTTP_201_CREATED) async def submit_vote( id: uuid.UUID, payload: VoteCreate, db: AsyncSession = Depends(get_db), identity: DuniterIdentity = Depends(get_current_identity), ) -> VoteOut: """Submit a vote to a session. Each identity can only vote once per session. Submitting again replaces the previous vote. The vote must include a cryptographic signature for on-chain proof. After submission, broadcasts a vote_update event via WebSocket. """ session = await _get_session(db, id) # Auto-close check session = await _check_session_expired(session, db) # Verify session is open if session.status != "open": raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Cette session de vote n'est pas ouverte", ) # Verify session hasn't ended if datetime.now(timezone.utc) > session.ends_at: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Cette session de vote est terminee", ) # Check if voter already voted -- replace if so existing_result = await db.execute( select(Vote).where( Vote.session_id == session.id, Vote.voter_id == identity.id, ) ) existing_vote = existing_result.scalar_one_or_none() if existing_vote is not None: # Deactivate old vote (keep for audit trail) existing_vote.is_active = False # Update tallies: subtract old vote session.votes_total -= 1 if existing_vote.vote_value == "for": session.votes_for -= 1 if existing_vote.voter_is_smith: session.smith_votes_for -= 1 if existing_vote.voter_is_techcomm: session.techcomm_votes_for -= 1 elif existing_vote.vote_value == "against": session.votes_against -= 1 # Create new vote vote = Vote( session_id=session.id, voter_id=identity.id, vote_value=payload.vote_value, nuanced_level=payload.nuanced_level, comment=payload.comment, signature=payload.signature, signed_payload=payload.signed_payload, voter_wot_status=identity.wot_status, voter_is_smith=identity.is_smith, voter_is_techcomm=identity.is_techcomm, ) db.add(vote) # Update tallies: add new vote session.votes_total += 1 if payload.vote_value == "for": session.votes_for += 1 if identity.is_smith: session.smith_votes_for += 1 if identity.is_techcomm: session.techcomm_votes_for += 1 elif payload.vote_value == "against": session.votes_against += 1 await db.commit() await db.refresh(vote) # Broadcast vote update via WebSocket await manager.broadcast(session.id, { "type": "vote_update", "session_id": str(session.id), "votes_for": session.votes_for, "votes_against": session.votes_against, "votes_total": session.votes_total, }) return VoteOut.model_validate(vote) @router.get("/sessions/{id}/votes", response_model=list[VoteOut]) async def list_votes( id: uuid.UUID, db: AsyncSession = Depends(get_db), active_only: bool = Query(default=True, description="Ne montrer que les votes actifs"), ) -> list[VoteOut]: """List all votes in a session.""" # Verify session exists await _get_session(db, id) stmt = select(Vote).where(Vote.session_id == id) if active_only: stmt = stmt.where(Vote.is_active.is_(True)) stmt = stmt.order_by(Vote.created_at.asc()) result = await db.execute(stmt) votes = result.scalars().all() return [VoteOut.model_validate(v) for v in votes] @router.get("/sessions/{id}/threshold", response_model=ThresholdDetailOut) async def get_threshold_details( id: uuid.UUID, db: AsyncSession = Depends(get_db), ) -> ThresholdDetailOut: """Return computed threshold details for a vote session. Includes WoT/Smith/TechComm thresholds, pass/fail status, participation rate, and the formula parameters used. Automatically closes the session if its deadline has passed. """ session = await _get_session(db, id) session = await _check_session_expired(session, db) try: details = await svc_get_threshold_details(id, db) except ValueError as exc: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) return ThresholdDetailOut(**details) @router.get("/sessions/{id}/result") async def get_vote_result( id: uuid.UUID, db: AsyncSession = Depends(get_db), ) -> dict: """Compute and return the current result for a vote session. Uses the WoT threshold formula linked through the voting protocol. Returns current tallies, computed threshold, and whether the vote passes. Automatically closes the session if its deadline has passed. """ session = await _get_session(db, id) session = await _check_session_expired(session, db) # Get the protocol and formula protocol = await _get_protocol_with_formula(db, session.voting_protocol_id) formula = protocol.formula_config result_data = _compute_result(session, formula) return { "session_id": str(session.id), "status": session.status, "votes_for": session.votes_for, "votes_against": session.votes_against, "votes_total": session.votes_total, "wot_size": session.wot_size, "smith_size": session.smith_size, "techcomm_size": session.techcomm_size, "smith_votes_for": session.smith_votes_for, "techcomm_votes_for": session.techcomm_votes_for, **result_data, } @router.post("/sessions/{id}/tally", response_model=VoteResultOut) async def force_tally( id: uuid.UUID, db: AsyncSession = Depends(get_db), identity: DuniterIdentity = Depends(get_current_identity), ) -> VoteResultOut: """Force a recount of a vote session. Requires authentication. Useful after a chain snapshot update or when recalculation is needed. Works on any session status. """ try: result = await svc_compute_result(id, db) except ValueError as exc: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) # Broadcast tally update via WebSocket await manager.broadcast(id, { "type": "tally_update", "session_id": str(id), "result": result.get("result"), "votes_for": result.get("votes_for", 0), "votes_against": result.get("votes_against", 0), "votes_total": result.get("votes_total", 0), }) return VoteResultOut( session_id=id, result=result.get("result"), threshold_required=float(result.get("threshold", 0)), votes_for=result.get("votes_for", 0), votes_against=result.get("votes_against", 0), votes_total=result.get("votes_total", 0), adopted=result.get("adopted", False), nuanced_breakdown=result.get("nuanced_breakdown"), )