from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.database import get_db
from app.engine.pricing import HouseholdData, compute_impacts, compute_tariff
from app.models import Commune, Household, TariffParams
from app.schemas import ImpactRowOut, TariffComputeRequest, TariffComputeResponse

router = APIRouter()


async def _load_commune_data(
    slug: str, db: AsyncSession
) -> tuple[list[HouseholdData], TariffParams]:
    """Load the households and tariff parameters of a commune.

    Args:
        slug: Value matched against ``Commune.slug``.
        db: Async SQLAlchemy session used for the three queries.

    Returns:
        A ``(households, params)`` pair: the commune's ``Household`` rows
        converted to ``HouseholdData``, and its ``TariffParams`` row.

    Raises:
        HTTPException: 404 if no commune matches *slug*; 404 if the commune
            has no ``TariffParams`` row; 400 if it has no households.
    """
    result = await db.execute(select(Commune).where(Commune.slug == slug))
    commune = result.scalar_one_or_none()
    if not commune:
        raise HTTPException(status_code=404, detail="Commune introuvable")

    params_result = await db.execute(
        select(TariffParams).where(TariffParams.commune_id == commune.id)
    )
    params = params_result.scalar_one_or_none()
    if not params:
        raise HTTPException(status_code=404, detail="Paramètres tarifs manquants")

    hh_result = await db.execute(
        select(Household).where(Household.commune_id == commune.id)
    )
    households_db = hh_result.scalars().all()
    if not households_db:
        raise HTTPException(
            status_code=400, detail="Aucun foyer importé pour cette commune"
        )

    # Detach from the ORM: the pricing engine only needs these three fields.
    households = [
        HouseholdData(
            volume_m3=h.volume_m3,
            status=h.status,
            price_paid_eur=h.price_paid_eur,
        )
        for h in households_db
    ]
    return households, params


@router.post("/compute", response_model=TariffComputeResponse)
async def compute(data: TariffComputeRequest, db: AsyncSession = Depends(get_db)):
    """Compute the tariff curve and per-volume impacts for a commune.

    ``recettes``, ``abop``, ``abos``, ``vmax`` and ``pmax`` are read from the
    commune's stored ``TariffParams``. ``vinf`` and the coefficients ``a``–``e``
    are taken from the request body.

    Args:
        data: Request body; ``commune_slug`` selects the commune.
        db: Async SQLAlchemy session, injected by FastAPI.

    Returns:
        ``TariffComputeResponse`` carrying the ``compute_tariff`` output (p0 and
        curves) plus the ``compute_impacts`` rows.

    Raises:
        HTTPException: propagated from ``_load_commune_data`` (404/400).
    """
    households, params = await _load_commune_data(data.commune_slug, db)

    # Both engine calls take exactly the same inputs. Build them once so the
    # curve and the impacts table cannot drift apart when a parameter changes.
    pricing_kwargs = {
        "recettes": params.recettes,
        "abop": params.abop,
        "abos": params.abos,
        "vinf": data.vinf,
        "vmax": params.vmax,
        "pmax": params.pmax,
        "a": data.a,
        "b": data.b,
        "c": data.c,
        "d": data.d,
        "e": data.e,
    }

    result = compute_tariff(households, **pricing_kwargs)
    # compute_impacts also returns a p0. The response reports result.p0 from
    # compute_tariff instead.
    # NOTE(review): presumably both agree since the inputs are identical — confirm.
    _, impacts = compute_impacts(households, **pricing_kwargs)

    return TariffComputeResponse(
        p0=result.p0,
        curve_volumes=result.curve_volumes,
        curve_prices_m3=result.curve_prices_m3,
        curve_bills_rp=result.curve_bills_rp,
        curve_bills_rs=result.curve_bills_rs,
        impacts=[
            ImpactRowOut(
                volume=imp.volume,
                old_price=imp.old_price,
                new_price_rp=imp.new_price_rp,
                new_price_rs=imp.new_price_rs,
            )
            for imp in impacts
        ],
    )