from typing import Any, Optional

from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update
from app.database import get_db
from app.models import Commune, Household, Vote, TariffParams, AdminUser
from app.schemas import VoteCreate, VoteOut, MedianOut, TariffComputeResponse, ImpactRowOut
from app.services.auth_service import get_current_citizen, get_current_admin
from app.engine.pricing import HouseholdData, compute_p0, compute_tariff, compute_impacts
from app.engine.current_model import compute_linear_tariff
from app.engine.median import VoteParams, compute_median

router = APIRouter()

# The six curve parameters, in the order they appear in every API payload.
_CURVE_KEYS = ("vinf", "a", "b", "c", "d", "e")

# Curve shown when there are no active votes and no admin-published curve.
_DEFAULT_CURVE = {"vinf": 400, "a": 0.5, "b": 0.5, "c": 0.5, "d": 0.5, "e": 0.5}


# ── Helpers ──

async def _get_commune_by_slug(slug: str, db: AsyncSession) -> Commune:
    """Return the commune whose slug is ``slug``.

    Raises:
        HTTPException: 404 if no commune has this slug.
    """
    result = await db.execute(select(Commune).where(Commune.slug == slug))
    commune = result.scalar_one_or_none()
    if not commune:
        raise HTTPException(status_code=404, detail="Commune introuvable")
    return commune


async def _load_commune_context(
    commune_id: int, db: AsyncSession
) -> tuple[Optional[TariffParams], list[HouseholdData]]:
    """Load tariff params and households for a commune.

    Returns:
        ``(params, households)``. ``params`` is None when the commune has no
        TariffParams row. Households are converted to ``HouseholdData``.
    """
    params_result = await db.execute(
        select(TariffParams).where(TariffParams.commune_id == commune_id)
    )
    params = params_result.scalar_one_or_none()

    hh_result = await db.execute(
        select(Household).where(Household.commune_id == commune_id)
    )
    households = [
        HouseholdData(
            volume_m3=h.volume_m3,
            status=h.status,
            price_paid_eur=h.price_paid_eur,
        )
        for h in hh_result.scalars().all()
    ]
    return params, households


async def _get_active_votes(commune_id: int, db: AsyncSession):
    """Return every vote of the commune that is still flagged active."""
    result = await db.execute(
        select(Vote).where(
            Vote.commune_id == commune_id,
            # `==` (not `is`) is required: SQLAlchemy builds the SQL
            # comparison from the overloaded operator.
            Vote.is_active == True,  # noqa: E712
        )
    )
    return result.scalars().all()


def _curve_of(source: Any) -> dict[str, Any]:
    """Read the six curve parameters (vinf, a..e) from ``source``'s attributes."""
    return {key: getattr(source, key) for key in _CURVE_KEYS}


def _tariff_kwargs(params: TariffParams, curve: dict[str, Any]) -> dict[str, Any]:
    """Build the keyword arguments shared by compute_p0 / compute_tariff / compute_impacts.

    Combines the commune's tariff settings with the six curve parameters.
    """
    return {
        "recettes": params.recettes,
        "abop": params.abop,
        "abos": params.abos,
        "vmax": params.vmax,
        "pmax": params.pmax,
        **curve,
    }


def _published_curve(params: TariffParams) -> Optional[dict[str, Any]]:
    """Return the admin-published curve plus its p0, or None if none is published."""
    if params.published_vinf is None:
        return None
    return {key: getattr(params, f"published_{key}") for key in (*_CURVE_KEYS, "p0")}


def _baseline_data(households: list[HouseholdData], params: TariffParams) -> dict[str, Any]:
    """Compute the baseline linear-model fields that are merged into the public payload."""
    baseline = compute_linear_tariff(
        households,
        recettes=params.recettes,
        abop=params.abop,
        abos=params.abos,
        vmax=params.vmax,
    )
    return {
        "p0_linear": baseline.p0,
        "baseline_volumes": baseline.curve_volumes,
        "baseline_bills_rp": baseline.curve_bills_rp,
        "baseline_bills_rs": baseline.curve_bills_rs,
        "baseline_price_m3_rp": baseline.curve_price_m3_rp,
        "baseline_price_m3_rs": baseline.curve_price_m3_rs,
    }


def _curve_response(
    households: list[HouseholdData],
    params: TariffParams,
    curve: dict[str, Any],
    *,
    has_votes: bool,
    vote_count: int,
    published: Optional[dict[str, Any]],
    baseline_data: dict[str, Any],
) -> dict[str, Any]:
    """Price ``curve`` and assemble the payload returned by ``current_curve``.

    Runs compute_tariff and compute_impacts with identical arguments. The
    curve itself is echoed back under the "median" key.
    """
    engine_kwargs = _tariff_kwargs(params, curve)
    tariff = compute_tariff(households, **engine_kwargs)
    _, impacts = compute_impacts(households, **engine_kwargs)
    return {
        "has_votes": has_votes,
        "vote_count": vote_count,
        # Tariff params for the frontend
        "params": {
            "recettes": params.recettes,
            "abop": params.abop,
            "abos": params.abos,
            "pmax": params.pmax,
            "vmax": params.vmax,
        },
        "published": published,
        "median": dict(curve),
        "p0": tariff.p0,
        "curve_volumes": tariff.curve_volumes,
        "curve_prices_m3": tariff.curve_prices_m3,
        "curve_bills_rp": tariff.curve_bills_rp,
        "curve_bills_rs": tariff.curve_bills_rs,
        "impacts": [
            {
                "volume": imp.volume,
                "old_price": imp.old_price,
                "new_price_rp": imp.new_price_rp,
                "new_price_rs": imp.new_price_rs,
            }
            for imp in impacts
        ],
        **baseline_data,
    }


# ── Public endpoint: overlay of all vote curves for citizens ──

@router.get("/communes/{slug}/votes/current/overlay")
async def current_overlay(slug: str, db: AsyncSession = Depends(get_db)):
    """Public: returns all active vote curves (params only, no auth required)."""
    commune = await _get_commune_by_slug(slug, db)
    votes = await _get_active_votes(commune.id, db)
    return [{**_curve_of(v), "computed_p0": v.computed_p0} for v in votes]


# ── Public endpoint: current median curve for citizens ──

@router.get("/communes/{slug}/votes/current")
async def current_curve(slug: str, db: AsyncSession = Depends(get_db)):
    """
    Public endpoint: returns the current median curve + baseline linear model.

    No auth required; this is what citizens see when they visit a commune page.

    If the commune has no tariff params or no households, only
    ``{"has_votes": False, "vote_count": 0}`` is returned. Otherwise the
    payload always includes the baseline linear model and the admin-published
    curve (or None). It also includes a priced curve:
    - with active votes: the median of those votes;
    - without votes: the published curve if any, else a built-in default.
    """
    commune = await _get_commune_by_slug(slug, db)
    params, households = await _load_commune_context(commune.id, db)
    if not params or not households:
        return {"has_votes": False, "vote_count": 0}

    baseline_data = _baseline_data(households, params)
    votes = await _get_active_votes(commune.id, db)
    published = _published_curve(params)

    if not votes:
        # Use published curve if available, otherwise default
        if published:
            curve = {key: published[key] for key in _CURVE_KEYS}
        else:
            curve = dict(_DEFAULT_CURVE)
        return _curve_response(
            households,
            params,
            curve,
            has_votes=False,
            vote_count=0,
            published=published,
            baseline_data=baseline_data,
        )

    median = compute_median([VoteParams(**_curve_of(v)) for v in votes])
    return _curve_response(
        households,
        params,
        _curve_of(median),
        has_votes=True,
        vote_count=len(votes),
        published=published,
        baseline_data=baseline_data,
    )


# ── Citizen: submit vote ──

@router.post("/communes/{slug}/votes", response_model=VoteOut)
async def submit_vote(
    slug: str,
    data: VoteCreate,
    db: AsyncSession = Depends(get_db),
    household: Household = Depends(get_current_citizen),
):
    """Record the authenticated household's vote, replacing its previous ones.

    ``computed_p0`` is None when the commune has no tariff params.

    Raises:
        HTTPException: 404 for an unknown commune; 403 when the household
            belongs to another commune.
    """
    commune = await _get_commune_by_slug(slug, db)
    if household.commune_id != commune.id:
        raise HTTPException(status_code=403, detail="Accès interdit à cette commune")

    # Deactivate this household's earlier votes so only the new one stays active.
    await db.execute(
        update(Vote)
        .where(Vote.household_id == household.id, Vote.is_active == True)  # noqa: E712
        .values(is_active=False)
    )

    params, households = await _load_commune_context(commune.id, db)
    curve = _curve_of(data)
    computed_p0 = (
        compute_p0(households, **_tariff_kwargs(params, curve)) if params else None
    )

    vote = Vote(
        commune_id=commune.id,
        household_id=household.id,
        **curve,
        computed_p0=computed_p0,
    )
    db.add(vote)
    household.has_voted = True
    await db.commit()
    await db.refresh(vote)
    return vote


# ── Admin: list votes ──

@router.get("/communes/{slug}/votes", response_model=list[VoteOut])
async def list_votes(
    slug: str,
    db: AsyncSession = Depends(get_db),
    admin: AdminUser = Depends(get_current_admin),
):
    """Admin: list the commune's active votes."""
    commune = await _get_commune_by_slug(slug, db)
    return await _get_active_votes(commune.id, db)


# ── Admin: median ──

@router.get("/communes/{slug}/votes/median", response_model=MedianOut)
async def vote_median(
    slug: str,
    db: AsyncSession = Depends(get_db),
    admin: AdminUser = Depends(get_current_admin),
):
    """Admin: median of the active votes with its p0.

    ``computed_p0`` is 0 when the commune has no tariff params.

    Raises:
        HTTPException: 404 for an unknown commune or when no vote is active.
    """
    commune = await _get_commune_by_slug(slug, db)
    votes = await _get_active_votes(commune.id, db)
    if not votes:
        raise HTTPException(status_code=404, detail="Aucun vote actif")

    median = compute_median([VoteParams(**_curve_of(v)) for v in votes])
    curve = _curve_of(median)

    params, households = await _load_commune_context(commune.id, db)
    computed_p0 = (
        compute_p0(households, **_tariff_kwargs(params, curve)) if params else 0
    )

    return MedianOut(**curve, computed_p0=computed_p0, vote_count=len(votes))


# ── Admin: overlay ──

@router.get("/communes/{slug}/votes/overlay")
async def vote_overlay(
    slug: str,
    db: AsyncSession = Depends(get_db),
    admin: AdminUser = Depends(get_current_admin),
):
    """Admin: every active vote curve, including the vote id."""
    commune = await _get_commune_by_slug(slug, db)
    votes = await _get_active_votes(commune.id, db)
    return [
        {"id": v.id, **_curve_of(v), "computed_p0": v.computed_p0}
        for v in votes
    ]