Initial commit: SejeteralO water tarification platform

Full-stack app for participatory water pricing using Bezier curves.
- Backend: FastAPI + SQLAlchemy + SQLite with JWT auth
- Frontend: Nuxt 4 + TypeScript with interactive SVG editor
- Math engine: cubic Bezier tarification with Cardano solver
- Admin: commune management, household import, vote monitoring, CMS
- Citizen: interactive curve editor, vote submission
- Docker-compose deployment ready

Includes fixes for:
- Impact table snake_case/camelCase property mismatch
- CMS content backend API + frontend editor (was stub)
- Admin route protection middleware
- Public content display on commune page
- Vote confirmation page link fix

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Yvv
2026-02-21 15:26:02 +01:00
commit b30e54a8f7
67 changed files with 16723 additions and 0 deletions

4
backend/.env.example Normal file
View File

@@ -0,0 +1,4 @@
DATABASE_URL=sqlite+aiosqlite:///./sejeteralo.db
SECRET_KEY=change-me-in-production-with-a-real-secret-key
DEBUG=true
CORS_ORIGINS=["http://localhost:3000"]

12
backend/Dockerfile Normal file
View File

@@ -0,0 +1,12 @@
FROM python:3.11-slim
WORKDIR /app
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY . .
EXPOSE 8000
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]

117
backend/alembic.ini Normal file
View File

@@ -0,0 +1,117 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts
# Use forward slashes (/) also on windows to provide an os agnostic path
script_location = alembic
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python>=3.9 or backports.zoneinfo library.
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to alembic/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
# version_path_separator = newline
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
sqlalchemy.url = driver://user:pass@localhost/dbname
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = --fix REVISION_SCRIPT_FILENAME
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARNING
handlers = console
qualname =
[logger_sqlalchemy]
level = WARNING
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

1
backend/alembic/README Normal file
View File

@@ -0,0 +1 @@
Generic single-database configuration.

54
backend/alembic/env.py Normal file
View File

@@ -0,0 +1,54 @@
from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
from app.database import Base
from app.models import models # noqa: F401 - import to register models
config = context.config
if config.config_file_name is not None:
fileConfig(config.config_file_name)
target_metadata = Base.metadata
# Use sqlite for dev
config.set_main_option("sqlalchemy.url", "sqlite:///./sejeteralo.db")
def run_migrations_offline() -> None:
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(
connection=connection, target_metadata=target_metadata
)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

View File

@@ -0,0 +1,26 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
${upgrades if upgrades else "pass"}
def downgrade() -> None:
${downgrades if downgrades else "pass"}

View File

@@ -0,0 +1,130 @@
"""initial schema
Revision ID: 25f534648ea7
Revises:
Create Date: 2026-02-21 05:29:28.228738
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '25f534648ea7'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('admin_users',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('email', sa.String(length=200), nullable=False),
sa.Column('hashed_password', sa.String(length=200), nullable=False),
sa.Column('full_name', sa.String(length=200), nullable=True),
sa.Column('role', sa.String(length=20), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_admin_users_email'), 'admin_users', ['email'], unique=True)
op.create_index(op.f('ix_admin_users_id'), 'admin_users', ['id'], unique=False)
op.create_table('communes',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('name', sa.String(length=200), nullable=False),
sa.Column('slug', sa.String(length=200), nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=True),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_communes_id'), 'communes', ['id'], unique=False)
op.create_index(op.f('ix_communes_slug'), 'communes', ['slug'], unique=True)
op.create_table('admin_commune',
sa.Column('admin_id', sa.Integer(), nullable=False),
sa.Column('commune_id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['admin_id'], ['admin_users.id'], ),
sa.ForeignKeyConstraint(['commune_id'], ['communes.id'], ),
sa.PrimaryKeyConstraint('admin_id', 'commune_id')
)
op.create_table('commune_contents',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('commune_id', sa.Integer(), nullable=False),
sa.Column('slug', sa.String(length=200), nullable=False),
sa.Column('title', sa.String(length=200), nullable=True),
sa.Column('body_markdown', sa.Text(), nullable=True),
sa.Column('updated_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['commune_id'], ['communes.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('commune_id', 'slug', name='uq_content_commune_slug')
)
op.create_index(op.f('ix_commune_contents_id'), 'commune_contents', ['id'], unique=False)
op.create_table('households',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('commune_id', sa.Integer(), nullable=False),
sa.Column('identifier', sa.String(length=200), nullable=False),
sa.Column('status', sa.String(length=10), nullable=False),
sa.Column('volume_m3', sa.Float(), nullable=False),
sa.Column('price_paid_eur', sa.Float(), nullable=True),
sa.Column('auth_code', sa.String(length=8), nullable=False),
sa.Column('has_voted', sa.Boolean(), nullable=True),
sa.ForeignKeyConstraint(['commune_id'], ['communes.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('commune_id', 'identifier', name='uq_household_commune_identifier')
)
op.create_index(op.f('ix_households_auth_code'), 'households', ['auth_code'], unique=True)
op.create_index(op.f('ix_households_id'), 'households', ['id'], unique=False)
op.create_table('tariff_params',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('commune_id', sa.Integer(), nullable=False),
sa.Column('abop', sa.Float(), nullable=True),
sa.Column('abos', sa.Float(), nullable=True),
sa.Column('recettes', sa.Float(), nullable=True),
sa.Column('pmax', sa.Float(), nullable=True),
sa.Column('vmax', sa.Float(), nullable=True),
sa.ForeignKeyConstraint(['commune_id'], ['communes.id'], ),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('commune_id')
)
op.create_index(op.f('ix_tariff_params_id'), 'tariff_params', ['id'], unique=False)
op.create_table('votes',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('commune_id', sa.Integer(), nullable=False),
sa.Column('household_id', sa.Integer(), nullable=False),
sa.Column('vinf', sa.Float(), nullable=False),
sa.Column('a', sa.Float(), nullable=False),
sa.Column('b', sa.Float(), nullable=False),
sa.Column('c', sa.Float(), nullable=False),
sa.Column('d', sa.Float(), nullable=False),
sa.Column('e', sa.Float(), nullable=False),
sa.Column('computed_p0', sa.Float(), nullable=True),
sa.Column('submitted_at', sa.DateTime(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=True),
sa.ForeignKeyConstraint(['commune_id'], ['communes.id'], ),
sa.ForeignKeyConstraint(['household_id'], ['households.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_votes_id'), 'votes', ['id'], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_votes_id'), table_name='votes')
op.drop_table('votes')
op.drop_index(op.f('ix_tariff_params_id'), table_name='tariff_params')
op.drop_table('tariff_params')
op.drop_index(op.f('ix_households_id'), table_name='households')
op.drop_index(op.f('ix_households_auth_code'), table_name='households')
op.drop_table('households')
op.drop_index(op.f('ix_commune_contents_id'), table_name='commune_contents')
op.drop_table('commune_contents')
op.drop_table('admin_commune')
op.drop_index(op.f('ix_communes_slug'), table_name='communes')
op.drop_index(op.f('ix_communes_id'), table_name='communes')
op.drop_table('communes')
op.drop_index(op.f('ix_admin_users_id'), table_name='admin_users')
op.drop_index(op.f('ix_admin_users_email'), table_name='admin_users')
op.drop_table('admin_users')
# ### end Alembic commands ###

0
backend/app/__init__.py Normal file
View File

19
backend/app/config.py Normal file
View File

@@ -0,0 +1,19 @@
from pydantic_settings import BaseSettings
from pathlib import Path
class Settings(BaseSettings):
APP_NAME: str = "SejeteralO"
DEBUG: bool = True
DATABASE_URL: str = "sqlite+aiosqlite:///./sejeteralo.db"
SECRET_KEY: str = "change-me-in-production-with-a-real-secret-key"
ALGORITHM: str = "HS256"
ADMIN_TOKEN_EXPIRE_HOURS: int = 24
CITIZEN_TOKEN_EXPIRE_HOURS: int = 4
BASE_DIR: Path = Path(__file__).resolve().parent.parent
CORS_ORIGINS: list[str] = ["http://localhost:3000", "http://localhost:3001", "http://localhost:3009"]
model_config = {"env_file": ".env", "env_file_encoding": "utf-8"}
settings = Settings()

21
backend/app/database.py Normal file
View File

@@ -0,0 +1,21 @@
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase
from app.config import settings
engine = create_async_engine(settings.DATABASE_URL, echo=settings.DEBUG)
async_session = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
class Base(DeclarativeBase):
pass
async def get_db():
async with async_session() as session:
yield session
async def init_db():
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)

View File

@@ -0,0 +1,23 @@
from app.engine.integrals import compute_integrals
from app.engine.pricing import (
HouseholdData,
TariffResult,
compute_p0,
compute_tariff,
compute_impacts,
)
from app.engine.current_model import compute_linear_tariff, LinearTariffResult
from app.engine.median import VoteParams, compute_median
__all__ = [
"compute_integrals",
"HouseholdData",
"TariffResult",
"compute_p0",
"compute_tariff",
"compute_impacts",
"compute_linear_tariff",
"LinearTariffResult",
"VoteParams",
"compute_median",
]

View File

@@ -0,0 +1,66 @@
"""
Current (linear) pricing model.
Ported from eau.py:256-354 (CurrentModel).
Pure Python + numpy, no matplotlib.
"""
from dataclasses import dataclass
from app.engine.pricing import HouseholdData
@dataclass
class LinearTariffResult:
"""Result of the linear tariff computation."""
p0: float # flat price per m³
curve_volumes: list[float]
curve_bills_rp: list[float]
curve_bills_rs: list[float]
curve_price_m3_rp: list[float]
curve_price_m3_rs: list[float]
def compute_linear_tariff(
households: list[HouseholdData],
recettes: float,
abop: float,
abos: float,
vmax: float = 2100,
nbpts: int = 200,
) -> LinearTariffResult:
"""
Compute the linear (current) pricing model.
p0 = (recettes - Σ abo) / Σ volume
"""
total_abo = 0.0
total_volume = 0.0
for h in households:
abo = abos if h.status == "RS" else abop
total_abo += abo
total_volume += max(h.volume_m3, 1e-5)
if total_abo >= recettes or total_volume == 0:
p0 = 0.0
else:
p0 = (recettes - total_abo) / total_volume
# Generate curves
import numpy as np
vv = np.linspace(1e-5, vmax, nbpts)
bills_rp = abop + p0 * vv
bills_rs = abos + p0 * vv
price_m3_rp = abop / vv + p0
price_m3_rs = abos / vv + p0
return LinearTariffResult(
p0=p0,
curve_volumes=vv.tolist(),
curve_bills_rp=bills_rp.tolist(),
curve_bills_rs=bills_rs.tolist(),
curve_price_m3_rp=price_m3_rp.tolist(),
curve_price_m3_rs=price_m3_rs.tolist(),
)

View File

@@ -0,0 +1,118 @@
"""
Integral computation for Bézier tariff curves.
Ported from eau.py:169-211 (NewModel.computeIntegrals).
Pure Python + numpy, no matplotlib.
"""
import numpy as np
def compute_integrals(
volume: float,
vinf: float,
vmax: float,
pmax: float,
a: float,
b: float,
c: float,
d: float,
e: float,
) -> tuple[float, float, float]:
"""
Compute (alpha1, alpha2, beta2) for a given consumption volume.
The total bill for a household consuming `volume` m³ is:
bill = abo + alpha1 * p0 + alpha2 * p0 + beta2
where p0 is the inflection price (computed separately to balance revenue).
Args:
volume: consumption in m³ for this household
vinf: inflection volume separating the two tiers
vmax: maximum volume (price = pmax at this volume)
pmax: maximum price per m³
a, b: shape parameters for tier 1 Bézier curve
c, d, e: shape parameters for tier 2 Bézier curve
Returns:
(alpha1, alpha2, beta2) tuple
"""
if volume <= vinf:
# Tier 1 only
T = _solve_tier1_t(volume, vinf, b)
alpha1 = _compute_alpha1(T, vinf, a, b)
return alpha1, 0.0, 0.0
else:
# Full tier 1 (T=1) + partial tier 2
alpha1 = _compute_alpha1(1.0, vinf, a, b)
# Tier 2
wmax = vmax - vinf
T = _solve_tier2_t(volume - vinf, wmax, c, d)
uu = _compute_uu(T, c, d, e)
alpha2 = (volume - vinf) - 3 * uu * wmax
beta2 = 3 * pmax * wmax * uu
return alpha1, alpha2, beta2
def _solve_tier1_t(volume: float, vinf: float, b: float) -> float:
"""Find T such that v(T) = volume for tier 1."""
if volume == 0:
return 0.0
if volume >= vinf:
return 1.0
# Solve: vinf * [(1 - 3b) * T³ + 3b * T²] = volume
# => (1-3b) * T³ + 3b * T² - volume/vinf = 0
p = [1 - 3 * b, 3 * b, 0, -volume / vinf]
roots = np.roots(p)
roots = np.unique(roots)
real_roots = np.real(roots[np.isreal(roots)])
mask = (real_roots <= 1.0) & (real_roots >= 0.0)
return float(real_roots[mask][0])
def _solve_tier2_t(w: float, wmax: float, c: float, d: float) -> float:
"""Find T such that w(T) = w for tier 2, where w = volume - vinf."""
if w == 0:
return 0.0
if w >= wmax:
return 1.0
# Solve: wmax * [(3(c+d-cd)-2)*T³ + 3(1-2c-d+cd)*T² + 3c*T] = w
p = [
3 * (c + d - c * d) - 2,
3 * (1 - 2 * c - d + c * d),
3 * c,
-w / wmax,
]
roots = np.roots(p)
roots = np.unique(roots)
real_roots = np.real(roots[np.isreal(roots)])
mask = (real_roots <= 1.0 + 1e-10) & (real_roots >= -1e-10)
if not mask.any():
# Fallback: closest root to [0,1]
return float(np.clip(np.real(roots[0]), 0.0, 1.0))
return float(np.clip(real_roots[mask][0], 0.0, 1.0))
def _compute_alpha1(T: float, vinf: float, a: float, b: float) -> float:
"""Compute alpha1 coefficient for tier 1."""
return 3 * vinf * (
T**6 / 6 * (-9 * a * b + 3 * a + 6 * b - 2)
+ T**5 / 5 * (24 * a * b - 6 * a - 13 * b + 3)
+ 3 * T**4 / 4 * (-7 * a * b + a + 2 * b)
+ T**3 / 3 * 6 * a * b
)
def _compute_uu(T: float, c: float, d: float, e: float) -> float:
"""Compute the uu intermediate value for tier 2."""
return (
(-3 * c * d + 9 * e * c * d + 3 * c - 9 * e * c + 3 * d - 9 * e * d + 6 * e - 2) * T**6 / 6
+ (2 * c * d - 15 * e * c * d - 4 * c + 21 * e * c - 2 * d + 15 * e * d - 12 * e + 2) * T**5 / 5
+ (6 * e * c * d + c - 15 * e * c - 6 * e * d + 6 * e) * T**4 / 4
+ (3 * e * c) * T**3 / 3
)

View File

@@ -0,0 +1,48 @@
"""
Median computation for vote parameters.
Computes the element-wise median of (vinf, a, b, c, d, e) across all active votes.
This parametric median is chosen over geometric median because:
- It's transparent and politically explainable
- The result is itself a valid set of Bézier parameters
"""
import numpy as np
from dataclasses import dataclass
@dataclass
class VoteParams:
"""The 6 citizen-adjustable parameters."""
vinf: float
a: float
b: float
c: float
d: float
e: float
def compute_median(votes: list[VoteParams]) -> VoteParams | None:
"""
Compute element-wise median of vote parameters.
Returns None if no votes provided.
"""
if not votes:
return None
vinfs = [v.vinf for v in votes]
a_s = [v.a for v in votes]
b_s = [v.b for v in votes]
c_s = [v.c for v in votes]
d_s = [v.d for v in votes]
e_s = [v.e for v in votes]
return VoteParams(
vinf=float(np.median(vinfs)),
a=float(np.median(a_s)),
b=float(np.median(b_s)),
c=float(np.median(c_s)),
d=float(np.median(d_s)),
e=float(np.median(e_s)),
)

View File

@@ -0,0 +1,196 @@
"""
Pricing computation for Bézier tariff model.
Ported from eau.py:120-167 (NewModel.updateComputation).
Pure Python + numpy, no matplotlib.
"""
import numpy as np
from dataclasses import dataclass
from app.engine.integrals import compute_integrals
@dataclass
class HouseholdData:
"""Minimal household data needed for computation."""
volume_m3: float
status: str # "RS", "RP", or "PRO"
price_paid_eur: float = 0.0
@dataclass
class TariffResult:
"""Result of a full tariff computation."""
p0: float
curve_volumes: list[float]
curve_prices_m3: list[float]
curve_bills_rp: list[float]
curve_bills_rs: list[float]
household_bills: list[float] # projected bill for each household
@dataclass
class ImpactRow:
"""Price impact for a specific volume level."""
volume: float
old_price: float
new_price_rp: float
new_price_rs: float
def compute_p0(
households: list[HouseholdData],
recettes: float,
abop: float,
abos: float,
vinf: float,
vmax: float,
pmax: float,
a: float,
b: float,
c: float,
d: float,
e: float,
) -> float:
"""
Compute p0 (inflection price) that balances total revenue.
p0 = (R - Σ(abo + β₂)) / Σ(α₁ + α₂)
"""
total_abo = 0.0
total_alpha = 0.0
total_beta = 0.0
for h in households:
abo = abos if h.status == "RS" else abop
total_abo += abo
vol = max(h.volume_m3, 1e-5) # avoid div by 0
alpha1, alpha2, beta2 = compute_integrals(vol, vinf, vmax, pmax, a, b, c, d, e)
total_alpha += alpha1 + alpha2
total_beta += beta2
if total_abo >= recettes:
return 0.0
if total_alpha == 0:
return 0.0
return (recettes - total_abo - total_beta) / total_alpha
def compute_tariff(
households: list[HouseholdData],
recettes: float,
abop: float,
abos: float,
vinf: float,
vmax: float,
pmax: float,
a: float,
b: float,
c: float,
d: float,
e: float,
nbpts: int = 200,
) -> TariffResult:
"""
Full tariff computation: p0, price curves, and per-household bills.
"""
p0 = compute_p0(households, recettes, abop, abos, vinf, vmax, pmax, a, b, c, d, e)
# Generate curve points
tt = np.linspace(0, 1 - 1e-6, nbpts)
# Tier 1 volumes and prices
vv1 = vinf * ((1 - 3 * b) * tt**3 + 3 * b * tt**2)
prix_m3_1 = p0 * ((3 * a - 2) * tt**3 + (-6 * a + 3) * tt**2 + 3 * a * tt)
# Tier 2 volumes and prices
vv2 = vinf + (vmax - vinf) * (
(3 * (c + d - c * d) - 2) * tt**3
+ 3 * (1 - 2 * c - d + c * d) * tt**2
+ 3 * c * tt
)
prix_m3_2 = p0 + (pmax - p0) * ((1 - 3 * e) * tt**3 + 3 * e * tt**2)
vv = np.concatenate([vv1, vv2])
prix_m3 = np.concatenate([prix_m3_1, prix_m3_2])
# Compute full bills (integral) for each curve point
alpha1_arr = np.zeros(len(vv))
alpha2_arr = np.zeros(len(vv))
beta2_arr = np.zeros(len(vv))
for iv, v in enumerate(vv):
alpha1_arr[iv], alpha2_arr[iv], beta2_arr[iv] = compute_integrals(
v, vinf, vmax, pmax, a, b, c, d, e
)
bills_rp = abop + (alpha1_arr + alpha2_arr) * p0 + beta2_arr
bills_rs = abos + (alpha1_arr + alpha2_arr) * p0 + beta2_arr
# Per-household projected bills
household_bills = []
for h in households:
vol = max(h.volume_m3, 1e-5)
abo = abos if h.status == "RS" else abop
a1, a2, b2 = compute_integrals(vol, vinf, vmax, pmax, a, b, c, d, e)
household_bills.append(abo + (a1 + a2) * p0 + b2)
return TariffResult(
p0=p0,
curve_volumes=vv.tolist(),
curve_prices_m3=prix_m3.tolist(),
curve_bills_rp=bills_rp.tolist(),
curve_bills_rs=bills_rs.tolist(),
household_bills=household_bills,
)
def compute_impacts(
households: list[HouseholdData],
recettes: float,
abop: float,
abos: float,
vinf: float,
vmax: float,
pmax: float,
a: float,
b: float,
c: float,
d: float,
e: float,
reference_volumes: list[float] | None = None,
) -> tuple[float, list[ImpactRow]]:
"""
Compute p0 and price impacts for reference volume levels.
Returns (p0, list of ImpactRow).
"""
if reference_volumes is None:
reference_volumes = [30, 60, 90, 150, 300]
p0 = compute_p0(households, recettes, abop, abos, vinf, vmax, pmax, a, b, c, d, e)
# Compute average 2018 price per m³ for a rough "old price" baseline
total_vol = sum(max(h.volume_m3, 1e-5) for h in households)
total_abo_old = sum(abos if h.status == "RS" else abop for h in households)
old_p_m3 = (recettes - total_abo_old) / total_vol if total_vol > 0 else 0
impacts = []
for vol in reference_volumes:
# Old price (linear model)
old_price_rp = abop + old_p_m3 * vol
# New price
a1, a2, b2 = compute_integrals(vol, vinf, vmax, pmax, a, b, c, d, e)
new_price_rp = abop + (a1 + a2) * p0 + b2
new_price_rs = abos + (a1 + a2) * p0 + b2
impacts.append(ImpactRow(
volume=vol,
old_price=old_price_rp,
new_price_rp=new_price_rp,
new_price_rs=new_price_rs,
))
return p0, impacts

42
backend/app/main.py Normal file
View File

@@ -0,0 +1,42 @@
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.config import settings
from app.database import init_db
from app.routers import auth, communes, content, tariff, votes, households
@asynccontextmanager
async def lifespan(app: FastAPI):
await init_db()
yield
app = FastAPI(
title=settings.APP_NAME,
description="Outil de démocratie participative pour la tarification de l'eau",
version="0.1.0",
lifespan=lifespan,
)
app.add_middleware(
CORSMiddleware,
allow_origins=settings.CORS_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.include_router(auth.router, prefix="/api/v1/auth", tags=["auth"])
app.include_router(communes.router, prefix="/api/v1/communes", tags=["communes"])
app.include_router(tariff.router, prefix="/api/v1/tariff", tags=["tariff"])
app.include_router(votes.router, prefix="/api/v1", tags=["votes"])
app.include_router(households.router, prefix="/api/v1", tags=["households"])
app.include_router(content.router, prefix="/api/v1/communes", tags=["content"])
@app.get("/api/health")
async def health():
return {"status": "ok"}

View File

@@ -0,0 +1,9 @@
from app.models.models import (
Commune, TariffParams, Household, AdminUser, Vote, CommuneContent,
admin_commune_table,
)
__all__ = [
"Commune", "TariffParams", "Household", "AdminUser", "Vote",
"CommuneContent", "admin_commune_table",
]

View File

@@ -0,0 +1,120 @@
"""SQLAlchemy ORM models."""
from datetime import datetime
from sqlalchemy import (
Column, Integer, String, Float, Boolean, DateTime, ForeignKey, Text, Table,
UniqueConstraint,
)
from sqlalchemy.orm import relationship
from app.database import Base
# Many-to-many: admin users <-> communes
admin_commune_table = Table(
"admin_commune",
Base.metadata,
Column("admin_id", Integer, ForeignKey("admin_users.id"), primary_key=True),
Column("commune_id", Integer, ForeignKey("communes.id"), primary_key=True),
)
class Commune(Base):
__tablename__ = "communes"
id = Column(Integer, primary_key=True, index=True)
name = Column(String(200), nullable=False)
slug = Column(String(200), unique=True, nullable=False, index=True)
description = Column(Text, default="")
is_active = Column(Boolean, default=True)
created_at = Column(DateTime, default=datetime.utcnow)
tariff_params = relationship("TariffParams", back_populates="commune", uselist=False)
households = relationship("Household", back_populates="commune")
votes = relationship("Vote", back_populates="commune")
contents = relationship("CommuneContent", back_populates="commune")
admins = relationship("AdminUser", secondary=admin_commune_table, back_populates="communes")
class TariffParams(Base):
__tablename__ = "tariff_params"
id = Column(Integer, primary_key=True, index=True)
commune_id = Column(Integer, ForeignKey("communes.id"), unique=True, nullable=False)
abop = Column(Float, default=100.0)
abos = Column(Float, default=100.0)
recettes = Column(Float, default=75000.0)
pmax = Column(Float, default=20.0)
vmax = Column(Float, default=2100.0)
commune = relationship("Commune", back_populates="tariff_params")
class Household(Base):
__tablename__ = "households"
id = Column(Integer, primary_key=True, index=True)
commune_id = Column(Integer, ForeignKey("communes.id"), nullable=False)
identifier = Column(String(200), nullable=False)
status = Column(String(10), nullable=False) # RS, RP, PRO
volume_m3 = Column(Float, nullable=False)
price_paid_eur = Column(Float, default=0.0)
auth_code = Column(String(8), unique=True, nullable=False, index=True)
has_voted = Column(Boolean, default=False)
commune = relationship("Commune", back_populates="households")
votes = relationship("Vote", back_populates="household")
__table_args__ = (
UniqueConstraint("commune_id", "identifier", name="uq_household_commune_identifier"),
)
class AdminUser(Base):
__tablename__ = "admin_users"
id = Column(Integer, primary_key=True, index=True)
email = Column(String(200), unique=True, nullable=False, index=True)
hashed_password = Column(String(200), nullable=False)
full_name = Column(String(200), default="")
role = Column(String(20), default="commune_admin") # super_admin / commune_admin
communes = relationship("Commune", secondary=admin_commune_table, back_populates="admins")
class Vote(Base):
__tablename__ = "votes"
id = Column(Integer, primary_key=True, index=True)
commune_id = Column(Integer, ForeignKey("communes.id"), nullable=False)
household_id = Column(Integer, ForeignKey("households.id"), nullable=False)
vinf = Column(Float, nullable=False)
a = Column(Float, nullable=False)
b = Column(Float, nullable=False)
c = Column(Float, nullable=False)
d = Column(Float, nullable=False)
e = Column(Float, nullable=False)
computed_p0 = Column(Float, nullable=True)
submitted_at = Column(DateTime, default=datetime.utcnow)
is_active = Column(Boolean, default=True)
commune = relationship("Commune", back_populates="votes")
household = relationship("Household", back_populates="votes")
class CommuneContent(Base):
__tablename__ = "commune_contents"
id = Column(Integer, primary_key=True, index=True)
commune_id = Column(Integer, ForeignKey("communes.id"), nullable=False)
slug = Column(String(200), nullable=False) # page identifier
title = Column(String(200), default="")
body_markdown = Column(Text, default="")
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
commune = relationship("Commune", back_populates="contents")
__table_args__ = (
UniqueConstraint("commune_id", "slug", name="uq_content_commune_slug"),
)

View File

View File

@@ -0,0 +1,84 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from app.database import get_db
from app.models import AdminUser, Household, Commune
from app.schemas import AdminLogin, CitizenVerify, Token, AdminUserCreate, AdminUserOut
from app.services.auth_service import (
verify_password, create_admin_token, create_citizen_token,
hash_password, require_super_admin,
)
router = APIRouter()
@router.post("/admin/login", response_model=Token)
async def admin_login(data: AdminLogin, db: AsyncSession = Depends(get_db)):
result = await db.execute(
select(AdminUser)
.options(selectinload(AdminUser.communes))
.where(AdminUser.email == data.email)
)
admin = result.scalar_one_or_none()
if not admin or not verify_password(data.password, admin.hashed_password):
raise HTTPException(status_code=401, detail="Identifiants invalides")
# For commune_admin, include their first commune slug
commune_slug = None
if admin.communes:
commune_slug = admin.communes[0].slug
return Token(
access_token=create_admin_token(admin),
role=admin.role,
commune_slug=commune_slug,
)
@router.post("/citizen/verify", response_model=Token)
async def citizen_verify(data: CitizenVerify, db: AsyncSession = Depends(get_db)):
result = await db.execute(
select(Household)
.join(Commune)
.where(Commune.slug == data.commune_slug, Household.auth_code == data.auth_code)
)
household = result.scalar_one_or_none()
if not household:
raise HTTPException(status_code=401, detail="Code invalide ou commune introuvable")
return Token(
access_token=create_citizen_token(household, data.commune_slug),
role="citizen",
commune_slug=data.commune_slug,
)
@router.post("/admin/create", response_model=AdminUserOut)
async def create_admin(
data: AdminUserCreate,
db: AsyncSession = Depends(get_db),
current: AdminUser = Depends(require_super_admin),
):
existing = await db.execute(select(AdminUser).where(AdminUser.email == data.email))
if existing.scalar_one_or_none():
raise HTTPException(status_code=400, detail="Email déjà utilisé")
admin = AdminUser(
email=data.email,
hashed_password=hash_password(data.password),
full_name=data.full_name,
role=data.role,
)
if data.commune_slugs:
for slug in data.commune_slugs:
result = await db.execute(select(Commune).where(Commune.slug == slug))
commune = result.scalar_one_or_none()
if commune:
admin.communes.append(commune)
db.add(admin)
await db.commit()
await db.refresh(admin)
return admin

View File

@@ -0,0 +1,128 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, delete
from app.database import get_db
from app.models import Commune, TariffParams, Household, Vote, CommuneContent, AdminUser, admin_commune_table
from app.schemas import (
CommuneCreate, CommuneUpdate, CommuneOut,
TariffParamsUpdate, TariffParamsOut,
)
from app.services.auth_service import get_current_admin, require_super_admin
router = APIRouter()
@router.get("/", response_model=list[CommuneOut])
async def list_communes(db: AsyncSession = Depends(get_db)):
result = await db.execute(select(Commune).where(Commune.is_active == True))
return result.scalars().all()
@router.post("/", response_model=CommuneOut)
async def create_commune(
data: CommuneCreate,
db: AsyncSession = Depends(get_db),
admin: AdminUser = Depends(require_super_admin),
):
existing = await db.execute(select(Commune).where(Commune.slug == data.slug))
if existing.scalar_one_or_none():
raise HTTPException(status_code=400, detail="Slug déjà utilisé")
commune = Commune(name=data.name, slug=data.slug, description=data.description)
db.add(commune)
await db.flush()
params = TariffParams(commune_id=commune.id)
db.add(params)
await db.commit()
await db.refresh(commune)
return commune
@router.get("/{slug}", response_model=CommuneOut)
async def get_commune(slug: str, db: AsyncSession = Depends(get_db)):
result = await db.execute(select(Commune).where(Commune.slug == slug))
commune = result.scalar_one_or_none()
if not commune:
raise HTTPException(status_code=404, detail="Commune introuvable")
return commune
@router.put("/{slug}", response_model=CommuneOut)
async def update_commune(
slug: str,
data: CommuneUpdate,
db: AsyncSession = Depends(get_db),
admin: AdminUser = Depends(get_current_admin),
):
result = await db.execute(select(Commune).where(Commune.slug == slug))
commune = result.scalar_one_or_none()
if not commune:
raise HTTPException(status_code=404, detail="Commune introuvable")
if data.name is not None:
commune.name = data.name
if data.description is not None:
commune.description = data.description
if data.is_active is not None:
commune.is_active = data.is_active
await db.commit()
await db.refresh(commune)
return commune
@router.delete("/{slug}")
async def delete_commune(
slug: str,
db: AsyncSession = Depends(get_db),
admin: AdminUser = Depends(require_super_admin),
):
result = await db.execute(select(Commune).where(Commune.slug == slug))
commune = result.scalar_one_or_none()
if not commune:
raise HTTPException(status_code=404, detail="Commune introuvable")
# Delete related data in order
await db.execute(delete(Vote).where(Vote.commune_id == commune.id))
await db.execute(delete(Household).where(Household.commune_id == commune.id))
await db.execute(delete(TariffParams).where(TariffParams.commune_id == commune.id))
await db.execute(delete(CommuneContent).where(CommuneContent.commune_id == commune.id))
await db.execute(delete(admin_commune_table).where(admin_commune_table.c.commune_id == commune.id))
await db.delete(commune)
await db.commit()
return {"detail": f"Commune '{slug}' supprimée"}
@router.get("/{slug}/params", response_model=TariffParamsOut)
async def get_params(slug: str, db: AsyncSession = Depends(get_db)):
result = await db.execute(
select(TariffParams).join(Commune).where(Commune.slug == slug)
)
params = result.scalar_one_or_none()
if not params:
raise HTTPException(status_code=404, detail="Paramètres introuvables")
return params
@router.put("/{slug}/params", response_model=TariffParamsOut)
async def update_params(
slug: str,
data: TariffParamsUpdate,
db: AsyncSession = Depends(get_db),
admin: AdminUser = Depends(get_current_admin),
):
result = await db.execute(
select(TariffParams).join(Commune).where(Commune.slug == slug)
)
params = result.scalar_one_or_none()
if not params:
raise HTTPException(status_code=404, detail="Paramètres introuvables")
for field, value in data.model_dump(exclude_unset=True).items():
setattr(params, field, value)
await db.commit()
await db.refresh(params)
return params

View File

@@ -0,0 +1,102 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.database import get_db
from app.models import Commune, CommuneContent, AdminUser
from app.schemas import ContentUpdate, ContentOut
from app.services.auth_service import get_current_admin
router = APIRouter()
async def _get_commune(slug: str, db: AsyncSession) -> Commune:
result = await db.execute(select(Commune).where(Commune.slug == slug))
commune = result.scalar_one_or_none()
if not commune:
raise HTTPException(status_code=404, detail="Commune introuvable")
return commune
@router.get("/{slug}/content", response_model=list[ContentOut])
async def list_content(slug: str, db: AsyncSession = Depends(get_db)):
"""List all content pages for a commune (public)."""
commune = await _get_commune(slug, db)
result = await db.execute(
select(CommuneContent)
.where(CommuneContent.commune_id == commune.id)
.order_by(CommuneContent.slug)
)
return result.scalars().all()
@router.get("/{slug}/content/{page_slug}", response_model=ContentOut)
async def get_content(slug: str, page_slug: str, db: AsyncSession = Depends(get_db)):
"""Get a specific content page (public)."""
commune = await _get_commune(slug, db)
result = await db.execute(
select(CommuneContent)
.where(CommuneContent.commune_id == commune.id)
.where(CommuneContent.slug == page_slug)
)
content = result.scalar_one_or_none()
if not content:
raise HTTPException(status_code=404, detail="Page introuvable")
return content
@router.put("/{slug}/content/{page_slug}", response_model=ContentOut)
async def upsert_content(
slug: str,
page_slug: str,
data: ContentUpdate,
db: AsyncSession = Depends(get_db),
admin: AdminUser = Depends(get_current_admin),
):
"""Create or update a content page (admin only)."""
commune = await _get_commune(slug, db)
result = await db.execute(
select(CommuneContent)
.where(CommuneContent.commune_id == commune.id)
.where(CommuneContent.slug == page_slug)
)
content = result.scalar_one_or_none()
if content:
content.title = data.title
content.body_markdown = data.body_markdown
else:
content = CommuneContent(
commune_id=commune.id,
slug=page_slug,
title=data.title,
body_markdown=data.body_markdown,
)
db.add(content)
await db.commit()
await db.refresh(content)
return content
@router.delete("/{slug}/content/{page_slug}")
async def delete_content(
slug: str,
page_slug: str,
db: AsyncSession = Depends(get_db),
admin: AdminUser = Depends(get_current_admin),
):
"""Delete a content page (admin only)."""
commune = await _get_commune(slug, db)
result = await db.execute(
select(CommuneContent)
.where(CommuneContent.commune_id == commune.id)
.where(CommuneContent.slug == page_slug)
)
content = result.scalar_one_or_none()
if not content:
raise HTTPException(status_code=404, detail="Page introuvable")
await db.delete(content)
await db.commit()
return {"detail": f"Page '{page_slug}' supprimée"}

View File

@@ -0,0 +1,117 @@
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func
import io
import numpy as np
from app.database import get_db
from app.models import Commune, Household, AdminUser
from app.schemas import HouseholdOut, HouseholdStats, ImportPreview, ImportResult
from app.services.auth_service import get_current_admin
from app.services.import_service import parse_import_file, import_households, generate_template_csv
router = APIRouter()
@router.get("/communes/{slug}/households/template")
async def download_template():
content = generate_template_csv()
return StreamingResponse(
io.BytesIO(content),
media_type="text/csv",
headers={"Content-Disposition": "attachment; filename=template_foyers.csv"},
)
@router.get("/communes/{slug}/households/stats", response_model=HouseholdStats)
async def household_stats(slug: str, db: AsyncSession = Depends(get_db)):
result = await db.execute(select(Commune).where(Commune.slug == slug))
commune = result.scalar_one_or_none()
if not commune:
raise HTTPException(status_code=404, detail="Commune introuvable")
hh_result = await db.execute(
select(Household).where(Household.commune_id == commune.id)
)
households = hh_result.scalars().all()
if not households:
return HouseholdStats(
total=0, rs_count=0, rp_count=0, pro_count=0,
total_volume=0, avg_volume=0, median_volume=0, voted_count=0,
)
volumes = [h.volume_m3 for h in households]
return HouseholdStats(
total=len(households),
rs_count=sum(1 for h in households if h.status == "RS"),
rp_count=sum(1 for h in households if h.status == "RP"),
pro_count=sum(1 for h in households if h.status == "PRO"),
total_volume=sum(volumes),
avg_volume=float(np.mean(volumes)),
median_volume=float(np.median(volumes)),
voted_count=sum(1 for h in households if h.has_voted),
)
@router.post("/communes/{slug}/households/import/preview", response_model=ImportPreview)
async def preview_import(
slug: str,
file: UploadFile = File(...),
db: AsyncSession = Depends(get_db),
admin: AdminUser = Depends(get_current_admin),
):
result = await db.execute(select(Commune).where(Commune.slug == slug))
commune = result.scalar_one_or_none()
if not commune:
raise HTTPException(status_code=404, detail="Commune introuvable")
content = await file.read()
df, errors = parse_import_file(content, file.filename)
if df is None:
return ImportPreview(valid_rows=0, errors=errors, sample=[])
valid_rows = len(df) - len(errors)
sample = df.head(5).to_dict(orient="records")
return ImportPreview(valid_rows=valid_rows, errors=errors, sample=sample)
@router.post("/communes/{slug}/households/import", response_model=ImportResult)
async def do_import(
slug: str,
file: UploadFile = File(...),
db: AsyncSession = Depends(get_db),
admin: AdminUser = Depends(get_current_admin),
):
result = await db.execute(select(Commune).where(Commune.slug == slug))
commune = result.scalar_one_or_none()
if not commune:
raise HTTPException(status_code=404, detail="Commune introuvable")
content = await file.read()
df, parse_errors = parse_import_file(content, file.filename)
if df is None or parse_errors:
raise HTTPException(status_code=400, detail={"errors": parse_errors})
created, import_errors = await import_households(db, commune.id, df)
return ImportResult(created=created, errors=import_errors)
@router.get("/communes/{slug}/households", response_model=list[HouseholdOut])
async def list_households(
slug: str,
db: AsyncSession = Depends(get_db),
admin: AdminUser = Depends(get_current_admin),
):
result = await db.execute(select(Commune).where(Commune.slug == slug))
commune = result.scalar_one_or_none()
if not commune:
raise HTTPException(status_code=404, detail="Commune introuvable")
hh_result = await db.execute(
select(Household).where(Household.commune_id == commune.id)
)
return hh_result.scalars().all()

View File

@@ -0,0 +1,96 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.database import get_db
from app.models import Commune, TariffParams, Household
from app.schemas import TariffComputeRequest, TariffComputeResponse, ImpactRowOut
from app.engine.pricing import HouseholdData, compute_tariff, compute_impacts
router = APIRouter()
async def _load_commune_data(
slug: str, db: AsyncSession
) -> tuple[list[HouseholdData], TariffParams]:
"""Load households and tariff params for a commune."""
result = await db.execute(select(Commune).where(Commune.slug == slug))
commune = result.scalar_one_or_none()
if not commune:
raise HTTPException(status_code=404, detail="Commune introuvable")
params_result = await db.execute(
select(TariffParams).where(TariffParams.commune_id == commune.id)
)
params = params_result.scalar_one_or_none()
if not params:
raise HTTPException(status_code=404, detail="Paramètres tarifs manquants")
hh_result = await db.execute(
select(Household).where(Household.commune_id == commune.id)
)
households_db = hh_result.scalars().all()
if not households_db:
raise HTTPException(status_code=400, detail="Aucun foyer importé pour cette commune")
households = [
HouseholdData(
volume_m3=h.volume_m3,
status=h.status,
price_paid_eur=h.price_paid_eur,
)
for h in households_db
]
return households, params
@router.post("/compute", response_model=TariffComputeResponse)
async def compute(data: TariffComputeRequest, db: AsyncSession = Depends(get_db)):
households, params = await _load_commune_data(data.commune_slug, db)
result = compute_tariff(
households,
recettes=params.recettes,
abop=params.abop,
abos=params.abos,
vinf=data.vinf,
vmax=params.vmax,
pmax=params.pmax,
a=data.a,
b=data.b,
c=data.c,
d=data.d,
e=data.e,
)
p0, impacts = compute_impacts(
households,
recettes=params.recettes,
abop=params.abop,
abos=params.abos,
vinf=data.vinf,
vmax=params.vmax,
pmax=params.pmax,
a=data.a,
b=data.b,
c=data.c,
d=data.d,
e=data.e,
)
return TariffComputeResponse(
p0=result.p0,
curve_volumes=result.curve_volumes,
curve_prices_m3=result.curve_prices_m3,
curve_bills_rp=result.curve_bills_rp,
curve_bills_rs=result.curve_bills_rs,
impacts=[
ImpactRowOut(
volume=imp.volume,
old_price=imp.old_price,
new_price_rp=imp.new_price_rp,
new_price_rs=imp.new_price_rs,
)
for imp in impacts
],
)

View File

@@ -0,0 +1,287 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update
from app.database import get_db
from app.models import Commune, Household, Vote, TariffParams, AdminUser
from app.schemas import VoteCreate, VoteOut, MedianOut, TariffComputeResponse, ImpactRowOut
from app.services.auth_service import get_current_citizen, get_current_admin
from app.engine.pricing import HouseholdData, compute_p0, compute_tariff, compute_impacts
from app.engine.current_model import compute_linear_tariff
from app.engine.median import VoteParams, compute_median
router = APIRouter()
async def _get_commune_by_slug(slug: str, db: AsyncSession) -> Commune:
result = await db.execute(select(Commune).where(Commune.slug == slug))
commune = result.scalar_one_or_none()
if not commune:
raise HTTPException(status_code=404, detail="Commune introuvable")
return commune
async def _load_commune_context(commune_id: int, db: AsyncSession):
"""Load tariff params and households for a commune."""
params_result = await db.execute(
select(TariffParams).where(TariffParams.commune_id == commune_id)
)
params = params_result.scalar_one_or_none()
hh_result = await db.execute(
select(Household).where(Household.commune_id == commune_id)
)
households_db = hh_result.scalars().all()
households = [
HouseholdData(volume_m3=h.volume_m3, status=h.status, price_paid_eur=h.price_paid_eur)
for h in households_db
]
return params, households
# ── Public endpoint: current median curve for citizens ──
@router.get("/communes/{slug}/votes/current")
async def current_curve(slug: str, db: AsyncSession = Depends(get_db)):
"""
Public endpoint: returns the current median curve + baseline linear model.
No auth required — this is what citizens see when they visit a commune page.
Always returns the baseline linear model.
Returns the median Bézier curve only if votes exist.
"""
commune = await _get_commune_by_slug(slug, db)
params, households = await _load_commune_context(commune.id, db)
if not params or not households:
return {"has_votes": False, "vote_count": 0}
# Always compute the baseline linear model
baseline = compute_linear_tariff(
households, recettes=params.recettes,
abop=params.abop, abos=params.abos, vmax=params.vmax,
)
baseline_data = {
"p0_linear": baseline.p0,
"baseline_volumes": baseline.curve_volumes,
"baseline_bills_rp": baseline.curve_bills_rp,
"baseline_bills_rs": baseline.curve_bills_rs,
"baseline_price_m3_rp": baseline.curve_price_m3_rp,
"baseline_price_m3_rs": baseline.curve_price_m3_rs,
}
# Tariff params for the frontend
tariff_params = {
"recettes": params.recettes,
"abop": params.abop,
"abos": params.abos,
"pmax": params.pmax,
"vmax": params.vmax,
}
# Get active votes
result = await db.execute(
select(Vote).where(Vote.commune_id == commune.id, Vote.is_active == True)
)
votes = result.scalars().all()
if not votes:
# Return default Bézier curve (a=b=c=d=e=0.5, vinf=vmax/2)
default_vinf = params.vmax / 2
default_tariff = compute_tariff(
households,
recettes=params.recettes, abop=params.abop, abos=params.abos,
vinf=default_vinf, vmax=params.vmax, pmax=params.pmax,
a=0.5, b=0.5, c=0.5, d=0.5, e=0.5,
)
_, default_impacts = compute_impacts(
households,
recettes=params.recettes, abop=params.abop, abos=params.abos,
vinf=default_vinf, vmax=params.vmax, pmax=params.pmax,
a=0.5, b=0.5, c=0.5, d=0.5, e=0.5,
)
return {
"has_votes": False,
"vote_count": 0,
"params": tariff_params,
"median": {
"vinf": default_vinf, "a": 0.5, "b": 0.5,
"c": 0.5, "d": 0.5, "e": 0.5,
},
"p0": default_tariff.p0,
"curve_volumes": default_tariff.curve_volumes,
"curve_prices_m3": default_tariff.curve_prices_m3,
"curve_bills_rp": default_tariff.curve_bills_rp,
"curve_bills_rs": default_tariff.curve_bills_rs,
"impacts": [
{"volume": imp.volume, "old_price": imp.old_price,
"new_price_rp": imp.new_price_rp, "new_price_rs": imp.new_price_rs}
for imp in default_impacts
],
**baseline_data,
}
# Compute median
vote_params = [
VoteParams(vinf=v.vinf, a=v.a, b=v.b, c=v.c, d=v.d, e=v.e)
for v in votes
]
median = compute_median(vote_params)
# Compute full tariff for the median
tariff = compute_tariff(
households,
recettes=params.recettes, abop=params.abop, abos=params.abos,
vinf=median.vinf, vmax=params.vmax, pmax=params.pmax,
a=median.a, b=median.b, c=median.c, d=median.d, e=median.e,
)
_, impacts = compute_impacts(
households,
recettes=params.recettes, abop=params.abop, abos=params.abos,
vinf=median.vinf, vmax=params.vmax, pmax=params.pmax,
a=median.a, b=median.b, c=median.c, d=median.d, e=median.e,
)
return {
"has_votes": True,
"vote_count": len(votes),
"params": tariff_params,
"median": {
"vinf": median.vinf,
"a": median.a,
"b": median.b,
"c": median.c,
"d": median.d,
"e": median.e,
},
"p0": tariff.p0,
"curve_volumes": tariff.curve_volumes,
"curve_prices_m3": tariff.curve_prices_m3,
"curve_bills_rp": tariff.curve_bills_rp,
"curve_bills_rs": tariff.curve_bills_rs,
"impacts": [
{"volume": imp.volume, "old_price": imp.old_price,
"new_price_rp": imp.new_price_rp, "new_price_rs": imp.new_price_rs}
for imp in impacts
],
**baseline_data,
}
# ── Citizen: submit vote ──
@router.post("/communes/{slug}/votes", response_model=VoteOut)
async def submit_vote(
slug: str,
data: VoteCreate,
db: AsyncSession = Depends(get_db),
household: Household = Depends(get_current_citizen),
):
commune = await _get_commune_by_slug(slug, db)
if household.commune_id != commune.id:
raise HTTPException(status_code=403, detail="Accès interdit à cette commune")
# Deactivate previous votes
await db.execute(
update(Vote)
.where(Vote.household_id == household.id, Vote.is_active == True)
.values(is_active=False)
)
params, households = await _load_commune_context(commune.id, db)
computed_p0 = compute_p0(
households,
recettes=params.recettes, abop=params.abop, abos=params.abos,
vinf=data.vinf, vmax=params.vmax, pmax=params.pmax,
a=data.a, b=data.b, c=data.c, d=data.d, e=data.e,
) if params else None
vote = Vote(
commune_id=commune.id,
household_id=household.id,
vinf=data.vinf, a=data.a, b=data.b, c=data.c, d=data.d, e=data.e,
computed_p0=computed_p0,
)
db.add(vote)
household.has_voted = True
await db.commit()
await db.refresh(vote)
return vote
# ── Admin: list votes ──
@router.get("/communes/{slug}/votes", response_model=list[VoteOut])
async def list_votes(
slug: str,
db: AsyncSession = Depends(get_db),
admin: AdminUser = Depends(get_current_admin),
):
commune = await _get_commune_by_slug(slug, db)
result = await db.execute(
select(Vote).where(Vote.commune_id == commune.id, Vote.is_active == True)
)
return result.scalars().all()
# ── Admin: median ──
@router.get("/communes/{slug}/votes/median", response_model=MedianOut)
async def vote_median(
slug: str,
db: AsyncSession = Depends(get_db),
admin: AdminUser = Depends(get_current_admin),
):
commune = await _get_commune_by_slug(slug, db)
result = await db.execute(
select(Vote).where(Vote.commune_id == commune.id, Vote.is_active == True)
)
votes = result.scalars().all()
if not votes:
raise HTTPException(status_code=404, detail="Aucun vote actif")
vote_params = [
VoteParams(vinf=v.vinf, a=v.a, b=v.b, c=v.c, d=v.d, e=v.e)
for v in votes
]
median = compute_median(vote_params)
params, households = await _load_commune_context(commune.id, db)
computed_p0 = compute_p0(
households,
recettes=params.recettes, abop=params.abop, abos=params.abos,
vinf=median.vinf, vmax=params.vmax, pmax=params.pmax,
a=median.a, b=median.b, c=median.c, d=median.d, e=median.e,
) if params else 0
return MedianOut(
vinf=median.vinf, a=median.a, b=median.b, c=median.c, d=median.d, e=median.e,
computed_p0=computed_p0, vote_count=len(votes),
)
# ── Admin: overlay ──
@router.get("/communes/{slug}/votes/overlay")
async def vote_overlay(
slug: str,
db: AsyncSession = Depends(get_db),
admin: AdminUser = Depends(get_current_admin),
):
commune = await _get_commune_by_slug(slug, db)
result = await db.execute(
select(Vote).where(Vote.commune_id == commune.id, Vote.is_active == True)
)
votes = result.scalars().all()
return [
{"id": v.id, "vinf": v.vinf, "a": v.a, "b": v.b,
"c": v.c, "d": v.d, "e": v.e, "computed_p0": v.computed_p0}
for v in votes
]

View File

@@ -0,0 +1 @@
from app.schemas.schemas import * # noqa: F401, F403

View File

@@ -0,0 +1,205 @@
"""Pydantic schemas for API request/response validation."""
from datetime import datetime
from pydantic import BaseModel, Field
# ── Auth ──
class AdminLogin(BaseModel):
email: str
password: str
class CitizenVerify(BaseModel):
commune_slug: str
auth_code: str
class Token(BaseModel):
access_token: str
token_type: str = "bearer"
role: str
commune_slug: str | None = None
# ── Commune ──
class CommuneCreate(BaseModel):
name: str
slug: str
description: str = ""
class CommuneUpdate(BaseModel):
name: str | None = None
description: str | None = None
is_active: bool | None = None
class CommuneOut(BaseModel):
id: int
name: str
slug: str
description: str
is_active: bool
created_at: datetime
model_config = {"from_attributes": True}
# ── TariffParams ──
class TariffParamsUpdate(BaseModel):
abop: float | None = None
abos: float | None = None
recettes: float | None = None
pmax: float | None = None
vmax: float | None = None
class TariffParamsOut(BaseModel):
abop: float
abos: float
recettes: float
pmax: float
vmax: float
model_config = {"from_attributes": True}
# ── Household ──
class HouseholdOut(BaseModel):
id: int
identifier: str
status: str
volume_m3: float
price_paid_eur: float
auth_code: str
has_voted: bool
model_config = {"from_attributes": True}
class HouseholdStats(BaseModel):
total: int
rs_count: int
rp_count: int
pro_count: int
total_volume: float
avg_volume: float
median_volume: float
voted_count: int
class ImportPreview(BaseModel):
valid_rows: int
errors: list[str]
sample: list[dict]
class ImportResult(BaseModel):
created: int
errors: list[str]
# ── Tariff Compute ──
class TariffComputeRequest(BaseModel):
commune_slug: str
vinf: float = Field(ge=0)
a: float = Field(ge=0, le=1)
b: float = Field(ge=0, le=1)
c: float = Field(ge=0, le=1)
d: float = Field(ge=0, le=1)
e: float = Field(ge=0, le=1)
class ImpactRowOut(BaseModel):
volume: float
old_price: float
new_price_rp: float
new_price_rs: float
class TariffComputeResponse(BaseModel):
p0: float
curve_volumes: list[float]
curve_prices_m3: list[float]
curve_bills_rp: list[float]
curve_bills_rs: list[float]
impacts: list[ImpactRowOut]
# ── Vote ──
class VoteCreate(BaseModel):
vinf: float = Field(ge=0)
a: float = Field(ge=0, le=1)
b: float = Field(ge=0, le=1)
c: float = Field(ge=0, le=1)
d: float = Field(ge=0, le=1)
e: float = Field(ge=0, le=1)
class VoteOut(BaseModel):
id: int
household_id: int
vinf: float
a: float
b: float
c: float
d: float
e: float
computed_p0: float | None
submitted_at: datetime
is_active: bool
model_config = {"from_attributes": True}
class MedianOut(BaseModel):
vinf: float
a: float
b: float
c: float
d: float
e: float
computed_p0: float
vote_count: int
# ── Admin User ──
class AdminUserCreate(BaseModel):
email: str
password: str
full_name: str = ""
role: str = "commune_admin"
commune_slugs: list[str] = []
class AdminUserOut(BaseModel):
id: int
email: str
full_name: str
role: str
model_config = {"from_attributes": True}
# ── Content ──
class ContentUpdate(BaseModel):
title: str
body_markdown: str
class ContentOut(BaseModel):
slug: str
title: str
body_markdown: str
updated_at: datetime
model_config = {"from_attributes": True}

View File

View File

@@ -0,0 +1,89 @@
"""Authentication service: JWT creation/validation, password hashing."""
from datetime import datetime, timedelta
from jose import jwt, JWTError
from passlib.context import CryptContext
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.config import settings
from app.database import get_db
from app.models import AdminUser, Household
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
security = HTTPBearer()
def hash_password(password: str) -> str:
return pwd_context.hash(password)
def verify_password(plain: str, hashed: str) -> bool:
return pwd_context.verify(plain, hashed)
def create_token(data: dict, expires_hours: int) -> str:
to_encode = data.copy()
to_encode["exp"] = datetime.utcnow() + timedelta(hours=expires_hours)
return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
def create_admin_token(admin: AdminUser) -> str:
return create_token(
{"sub": str(admin.id), "role": admin.role, "type": "admin"},
settings.ADMIN_TOKEN_EXPIRE_HOURS,
)
def create_citizen_token(household: Household, commune_slug: str) -> str:
return create_token(
{
"sub": str(household.id),
"commune_id": household.commune_id,
"commune_slug": commune_slug,
"type": "citizen",
},
settings.CITIZEN_TOKEN_EXPIRE_HOURS,
)
def decode_token(token: str) -> dict:
try:
return jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
except JWTError:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")
async def get_current_admin(
credentials: HTTPAuthorizationCredentials = Depends(security),
db: AsyncSession = Depends(get_db),
) -> AdminUser:
payload = decode_token(credentials.credentials)
if payload.get("type") != "admin":
raise HTTPException(status_code=403, detail="Admin access required")
admin = await db.get(AdminUser, int(payload["sub"]))
if not admin:
raise HTTPException(status_code=401, detail="Admin not found")
return admin
async def get_current_citizen(
credentials: HTTPAuthorizationCredentials = Depends(security),
db: AsyncSession = Depends(get_db),
) -> Household:
payload = decode_token(credentials.credentials)
if payload.get("type") != "citizen":
raise HTTPException(status_code=403, detail="Citizen access required")
household = await db.get(Household, int(payload["sub"]))
if not household:
raise HTTPException(status_code=401, detail="Household not found")
return household
def require_super_admin(admin: AdminUser = Depends(get_current_admin)) -> AdminUser:
if admin.role != "super_admin":
raise HTTPException(status_code=403, detail="Super admin access required")
return admin

View File

@@ -0,0 +1,143 @@
"""Service for importing household data from CSV/XLSX files."""
import io
import secrets
import string
import pandas as pd
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.models import Household
# Characters without ambiguous ones (O/0/I/1/l)
SAFE_CHARS = string.ascii_uppercase.replace("O", "").replace("I", "") + string.digits.replace("0", "").replace("1", "")
def generate_auth_code(length: int = 8) -> str:
return "".join(secrets.choice(SAFE_CHARS) for _ in range(length))
VALID_STATUSES = {"RS", "RP", "PRO"}
REQUIRED_COLUMNS = {"identifier", "status", "volume_m3", "price_eur"}
def parse_import_file(file_bytes: bytes, filename: str) -> tuple[pd.DataFrame | None, list[str]]:
"""
Parse a CSV or XLSX file and validate its contents.
Returns (dataframe, errors). If errors is non-empty, dataframe may be None.
"""
errors = []
try:
if filename.endswith(".csv"):
df = pd.read_csv(io.BytesIO(file_bytes))
elif filename.endswith((".xlsx", ".xls")):
df = pd.read_excel(io.BytesIO(file_bytes))
else:
return None, ["Format non supporté. Utilisez CSV ou XLSX."]
except Exception as e:
return None, [f"Erreur de lecture du fichier: {e}"]
# Normalize column names
df.columns = [c.strip().lower().replace(" ", "_") for c in df.columns]
missing = REQUIRED_COLUMNS - set(df.columns)
if missing:
return None, [f"Colonnes manquantes: {', '.join(missing)}"]
# Validate rows
for idx, row in df.iterrows():
line = idx + 2 # Excel line number (1-indexed + header)
status = str(row["status"]).strip().upper()
if status not in VALID_STATUSES:
errors.append(f"Ligne {line}: statut '{row['status']}' invalide (attendu: RS, RP, PRO)")
try:
vol = float(row["volume_m3"])
if vol < 0:
errors.append(f"Ligne {line}: volume négatif ({vol})")
except (ValueError, TypeError):
errors.append(f"Ligne {line}: volume invalide '{row['volume_m3']}'")
price = row.get("price_eur")
if pd.notna(price):
try:
p = float(price)
if p < 0:
errors.append(f"Ligne {line}: prix négatif ({p})")
except (ValueError, TypeError):
errors.append(f"Ligne {line}: prix invalide '{price}'")
# Normalize
df["status"] = df["status"].str.strip().str.upper()
df["identifier"] = df["identifier"].astype(str).str.strip()
return df, errors
async def import_households(
db: AsyncSession,
commune_id: int,
df: pd.DataFrame,
) -> tuple[int, list[str]]:
"""
Import validated households into the database.
Returns (created_count, errors).
"""
created = 0
errors = []
# Get existing auth codes to avoid collisions
existing_codes = set()
result = await db.execute(select(Household.auth_code))
for row in result.scalars():
existing_codes.add(row)
for idx, row in df.iterrows():
identifier = str(row["identifier"]).strip()
status = str(row["status"]).strip().upper()
volume = float(row["volume_m3"])
price = float(row["price_eur"]) if pd.notna(row.get("price_eur")) else 0.0
# Check for duplicate
existing = await db.execute(
select(Household).where(
Household.commune_id == commune_id,
Household.identifier == identifier,
)
)
if existing.scalar_one_or_none():
errors.append(f"Foyer '{identifier}' existe déjà, ignoré.")
continue
# Generate unique auth code
code = generate_auth_code()
while code in existing_codes:
code = generate_auth_code()
existing_codes.add(code)
household = Household(
commune_id=commune_id,
identifier=identifier,
status=status,
volume_m3=volume,
price_paid_eur=price,
auth_code=code,
)
db.add(household)
created += 1
await db.commit()
return created, errors
def generate_template_csv() -> bytes:
"""Generate a template CSV file for household import."""
content = "identifier,status,volume_m3,price_eur\n"
content += "DUPONT Jean,RS,85.5,189.50\n"
content += "MARTIN Pierre,RP,120.0,245.00\n"
content += "SARL Boulangerie,PRO,350.0,\n"
return content.encode("utf-8")

18
backend/requirements.txt Normal file
View File

@@ -0,0 +1,18 @@
fastapi==0.115.6
uvicorn[standard]==0.34.0
sqlalchemy==2.0.36
alembic==1.14.0
pydantic==2.10.3
pydantic-settings==2.7.0
python-jose[cryptography]==3.3.0
passlib[bcrypt]==1.7.4
bcrypt==4.1.3
python-multipart==0.0.18
numpy==1.26.4
openpyxl==3.1.5
xlrd==2.0.1
pandas==2.2.3
aiosqlite==0.20.0
pytest==8.3.4
pytest-asyncio==0.24.0
httpx==0.28.1

103
backend/seed.py Normal file
View File

@@ -0,0 +1,103 @@
"""Seed the database with Saoû data from Eau2018.xls."""
import asyncio
import sys
import os
sys.path.insert(0, os.path.dirname(__file__))
import xlrd
from sqlalchemy import select
from app.database import engine, async_session, init_db
from app.models import Commune, TariffParams, Household, AdminUser
from app.services.auth_service import hash_password
from app.services.import_service import generate_auth_code
XLS_PATH = os.path.join(os.path.dirname(__file__), "..", "Eau2018.xls")
async def seed():
await init_db()
async with async_session() as db:
# Check if already seeded
result = await db.execute(select(Commune).where(Commune.slug == "saou"))
if result.scalar_one_or_none():
print("Saoû already seeded.")
return
# Create commune
commune = Commune(
name="Saoû",
slug="saou",
description="Commune de Saoû - Tarification progressive de l'eau",
)
db.add(commune)
await db.flush()
# Create tariff params
params = TariffParams(
commune_id=commune.id,
abop=100,
abos=100,
recettes=75000,
pmax=20,
vmax=2100,
)
db.add(params)
# Create super admin (manages all communes)
super_admin = AdminUser(
email="superadmin@sejeteralo.fr",
hashed_password=hash_password("superadmin"),
full_name="Super Admin",
role="super_admin",
)
db.add(super_admin)
# Create commune admin for Saoû (manages only this commune)
commune_admin = AdminUser(
email="saou@sejeteralo.fr",
hashed_password=hash_password("saou2024"),
full_name="Admin Saoû",
role="commune_admin",
)
commune_admin.communes.append(commune)
db.add(commune_admin)
# Import households from Eau2018.xls
book = xlrd.open_workbook(XLS_PATH)
sheet = book.sheet_by_name("CALCULS")
nb_hab = 363
existing_codes = set()
for r in range(1, nb_hab + 1):
name = sheet.cell_value(r, 0)
status = sheet.cell_value(r, 3)
volume = sheet.cell_value(r, 4)
price = sheet.cell_value(r, 33)
code = generate_auth_code()
while code in existing_codes:
code = generate_auth_code()
existing_codes.add(code)
household = Household(
commune_id=commune.id,
identifier=str(name).strip(),
status=str(status).strip().upper(),
volume_m3=float(volume),
price_paid_eur=float(price) if price else 0.0,
auth_code=code,
)
db.add(household)
await db.commit()
print(f"Seeded: commune 'saou', {nb_hab} households")
print(f" Super admin: superadmin@sejeteralo.fr / superadmin")
print(f" Commune admin Saoû: saou@sejeteralo.fr / saou2024")
if __name__ == "__main__":
asyncio.run(seed())

View File

View File

@@ -0,0 +1,280 @@
"""
Tests for the extracted math engine.
Validates that the engine produces identical results to the original eau.py
using the Saoû data (Eau2018.xls).
"""
import sys
import os
import numpy as np
import pytest
import xlrd
# Add backend to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from app.engine.integrals import compute_integrals
from app.engine.pricing import HouseholdData, compute_p0, compute_tariff
from app.engine.current_model import compute_linear_tariff
from app.engine.median import VoteParams, compute_median
# Path to the Excel file
DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "..", "data")
XLS_PATH = os.path.join(os.path.dirname(__file__), "..", "..", "Eau2018.xls")
def load_saou_households() -> list[HouseholdData]:
"""Load household data from Eau2018.xls exactly as eau.py does."""
book = xlrd.open_workbook(XLS_PATH)
sheet = book.sheet_by_name("CALCULS")
nb_hab = 363
households = []
for r in range(1, nb_hab + 1):
vol = sheet.cell_value(r, 4)
status = sheet.cell_value(r, 3)
prix = sheet.cell_value(r, 33)
households.append(HouseholdData(
volume_m3=vol,
status=status,
price_paid_eur=prix,
))
return households
# Reference original eau.py computeIntegrals for comparison
def original_compute_integrals(vv, vinf, vmax, pmax, a, b, c, d, e):
"""Direct port of eau.py computeIntegrals for validation."""
if vv <= vinf:
if vv == 0:
T = 0.0
elif vv == vinf:
T = 1.0
else:
p = [1 - 3 * b, 3 * b, 0, -vv / vinf]
roots = np.roots(p)
roots = np.unique(roots)
roots2 = np.real(roots[np.isreal(roots)])
mask = (roots2 <= 1.0) & (roots2 >= 0.0)
T = float(roots2[mask])
alpha1 = 3 * vinf * (
T**6 / 6 * (-9 * a * b + 3 * a + 6 * b - 2)
+ T**5 / 5 * (24 * a * b - 6 * a - 13 * b + 3)
+ 3 * T**4 / 4 * (-7 * a * b + a + 2 * b)
+ T**3 / 3 * 6 * a * b
)
return alpha1, 0, 0
else:
alpha1 = 3 * vinf * (
1 / 6 * (-9 * a * b + 3 * a + 6 * b - 2)
+ 1 / 5 * (24 * a * b - 6 * a - 13 * b + 3)
+ 3 / 4 * (-7 * a * b + a + 2 * b)
+ 1 / 3 * 6 * a * b
)
wmax = vmax - vinf
if vv == vinf:
T = 0.0
elif vv == vmax:
T = 1.0
else:
p = [3 * (c + d - c * d) - 2, 3 * (1 - 2 * c - d + c * d), 3 * c, -(vv - vinf) / wmax]
roots = np.roots(p)
roots = np.unique(roots)
roots2 = np.real(roots[np.isreal(roots)])
mask = (roots2 <= 1.0) & (roots2 >= 0.0)
T = float(np.real(roots2[mask]))
uu = (
(-3 * c * d + 9 * e * c * d + 3 * c - 9 * e * c + 3 * d - 9 * e * d + 6 * e - 2) * T**6 / 6
+ (2 * c * d - 15 * e * c * d - 4 * c + 21 * e * c - 2 * d + 15 * e * d - 12 * e + 2) * T**5 / 5
+ (6 * e * c * d + c - 15 * e * c - 6 * e * d + 6 * e) * T**4 / 4
+ (3 * e * c) * T**3 / 3
)
alpha2 = vv - vinf - 3 * uu * wmax
beta2 = 3 * pmax * wmax * uu
return alpha1, alpha2, beta2
class TestIntegrals:
"""Test the integral computation against the original."""
def test_tier1_zero_volume(self):
a1, a2, b2 = compute_integrals(0, 1050, 2100, 20, 0.5, 0.5, 0.5, 0.5, 0.5)
assert a1 == 0.0
assert a2 == 0.0
assert b2 == 0.0
def test_tier1_at_vinf(self):
a1, a2, b2 = compute_integrals(1050, 1050, 2100, 20, 0.5, 0.5, 0.5, 0.5, 0.5)
oa1, oa2, ob2 = original_compute_integrals(1050, 1050, 2100, 20, 0.5, 0.5, 0.5, 0.5, 0.5)
assert abs(a1 - oa1) < 1e-10
assert a2 == 0.0
assert b2 == 0.0
def test_tier2_at_vmax(self):
a1, a2, b2 = compute_integrals(2100, 1050, 2100, 20, 0.5, 0.5, 0.5, 0.5, 0.5)
oa1, oa2, ob2 = original_compute_integrals(2100, 1050, 2100, 20, 0.5, 0.5, 0.5, 0.5, 0.5)
assert abs(a1 - oa1) < 1e-10
assert abs(a2 - oa2) < 1e-6
assert abs(b2 - ob2) < 1e-6
def test_various_volumes_match_original(self):
"""Test multiple volumes with various parameter sets."""
params_sets = [
(0.5, 0.5, 0.5, 0.5, 0.5),
(0.25, 0.75, 0.3, 0.6, 0.8),
(0.1, 0.1, 0.9, 0.9, 0.1),
(0.9, 0.9, 0.1, 0.1, 0.9),
]
volumes = [0, 10, 50, 100, 300, 500, 1000, 1050, 1051, 1500, 2000, 2100]
for a, b, c, d, e in params_sets:
for vol in volumes:
vinf, vmax, pmax = 1050, 2100, 20
result = compute_integrals(vol, vinf, vmax, pmax, a, b, c, d, e)
expected = original_compute_integrals(vol, vinf, vmax, pmax, a, b, c, d, e)
for i in range(3):
assert abs(result[i] - expected[i]) < 1e-6, (
f"Mismatch at vol={vol}, params=({a},{b},{c},{d},{e}), "
f"component={i}: got {result[i]}, expected {expected[i]}"
)
class TestPricing:
"""Test the pricing computation with Saoû data."""
@pytest.fixture
def saou_households(self):
return load_saou_households()
def test_saou_data_loaded(self, saou_households):
assert len(saou_households) == 363
def test_p0_default_params(self, saou_households):
"""Test p0 with default slider values from eau.py mainFunction."""
# Default values from eau.py lines 54-62
p0 = compute_p0(
saou_households,
recettes=75000, # recettesArray[25]
abop=100, # abopArray[100]
abos=100, # abosArray[100]
vinf=1050, # vinfArray[vmax/2]
vmax=2100,
pmax=20,
a=0.5, # aArray[25]
b=0.5, # bArray[25]
c=0.5, # cArray[25]
d=0.5, # dArray[25]
e=0.5, # eArray[25]
)
# Compute the same p0 using original algorithm
volumes = np.array([max(h.volume_m3, 1e-5) for h in saou_households])
statuses = np.array([h.status for h in saou_households])
abo = 100 * np.ones(363)
abo[statuses == "RS"] = 100
alpha1_arr = np.zeros(363)
alpha2_arr = np.zeros(363)
beta2_arr = np.zeros(363)
for ih in range(363):
alpha1_arr[ih], alpha2_arr[ih], beta2_arr[ih] = original_compute_integrals(
volumes[ih], 1050, 2100, 20, 0.5, 0.5, 0.5, 0.5, 0.5
)
expected_p0 = (75000 - np.sum(beta2_arr + abo)) / np.sum(alpha1_arr + alpha2_arr)
assert abs(p0 - expected_p0) < 1e-6, f"p0={p0}, expected={expected_p0}"
assert p0 > 0, "p0 should be positive"
def test_p0_various_params(self, saou_households):
"""Test p0 with various parameter sets."""
param_sets = [
(75000, 100, 100, 1050, 0.5, 0.5, 0.5, 0.5, 0.5),
(60000, 80, 80, 800, 0.3, 0.7, 0.4, 0.6, 0.2),
(90000, 120, 90, 1200, 0.8, 0.2, 0.6, 0.3, 0.7),
]
for recettes, abop, abos, vinf, a, b, c, d, e in param_sets:
p0 = compute_p0(
saou_households, recettes, abop, abos, vinf, 2100, 20, a, b, c, d, e
)
# Verify: total bills should equal recettes
total = 0
for h in saou_households:
vol = max(h.volume_m3, 1e-5)
abo_val = abos if h.status == "RS" else abop
a1, a2, b2 = compute_integrals(vol, vinf, 2100, 20, a, b, c, d, e)
total += abo_val + (a1 + a2) * p0 + b2
assert abs(total - recettes) < 0.01, (
f"Revenue mismatch: got {total:.2f}, expected {recettes}. "
f"Params: recettes={recettes}, abop={abop}, abos={abos}, vinf={vinf}, "
f"a={a}, b={b}, c={c}, d={d}, e={e}"
)
def test_full_tariff_computation(self, saou_households):
result = compute_tariff(
saou_households,
recettes=75000,
abop=100,
abos=100,
vinf=1050,
vmax=2100,
pmax=20,
a=0.5,
b=0.5,
c=0.5,
d=0.5,
e=0.5,
)
assert result.p0 > 0
assert len(result.curve_volumes) == 400 # 200 * 2 tiers
assert len(result.household_bills) == 363
class TestLinearModel:
"""Test the linear (current) pricing model."""
@pytest.fixture
def saou_households(self):
return load_saou_households()
def test_linear_p0(self, saou_households):
result = compute_linear_tariff(saou_households, recettes=75000, abop=100, abos=100)
# p0 = (recettes - sum_abo) / sum_volume
volumes = [max(h.volume_m3, 1e-5) for h in saou_households]
total_vol = sum(volumes)
expected_p0 = (75000 - 363 * 100) / total_vol # all RS have same abo in this case
assert abs(result.p0 - expected_p0) < 1e-6
class TestMedian:
"""Test the median computation."""
def test_single_vote(self):
votes = [VoteParams(vinf=1050, a=0.5, b=0.5, c=0.5, d=0.5, e=0.5)]
m = compute_median(votes)
assert m.vinf == 1050
assert m.a == 0.5
def test_odd_votes(self):
votes = [
VoteParams(vinf=800, a=0.3, b=0.2, c=0.4, d=0.5, e=0.6),
VoteParams(vinf=1000, a=0.5, b=0.5, c=0.5, d=0.5, e=0.5),
VoteParams(vinf=1200, a=0.7, b=0.8, c=0.6, d=0.5, e=0.4),
]
m = compute_median(votes)
assert m.vinf == 1000
assert m.a == 0.5
def test_even_votes(self):
votes = [
VoteParams(vinf=800, a=0.3, b=0.2, c=0.4, d=0.5, e=0.6),
VoteParams(vinf=1200, a=0.7, b=0.8, c=0.6, d=0.5, e=0.4),
]
m = compute_median(votes)
assert m.vinf == 1000 # average of 800, 1200
assert abs(m.a - 0.5) < 1e-10
def test_empty_votes(self):
assert compute_median([]) is None