Sprint 5 : integration et production -- securite, performance, API publique, documentation

Backend: rate limiter, security headers, blockchain cache service avec RPC,
public API (7 endpoints read-only), WebSocket auth + heartbeat, DB connection
pooling, structured logging, health check DB. Frontend: API retry/timeout,
WebSocket auth + heartbeat + typed events, notifications toast, mobile hamburger
+ drawer, error boundary, offline banner, loading skeletons, dashboard enrichi.
Documentation: guides utilisateur complets (demarrage, vote, sanctuaire, FAQ 30+),
guide deploiement, politique securite. 123 tests, 155 fichiers.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Yvv
2026-02-28 15:12:50 +01:00
parent 3cb1754592
commit 403b94fa2c
31 changed files with 4472 additions and 356 deletions

View File

@@ -6,24 +6,41 @@ class Settings(BaseSettings):
# Application identity.
APP_NAME: str = "Glibredecision"
# NOTE(review): DEBUG defaults to True -- presumably overridden via env in
# production; confirm deployment configuration.
DEBUG: bool = True
# Environment: one of "development", "staging", "production".
ENVIRONMENT: str = "development"
LOG_LEVEL: str = "INFO"
# Database (async driver). Placeholder credentials must be replaced in prod.
DATABASE_URL: str = "postgresql+asyncpg://glibredecision:change-me-in-production@localhost:5432/glibredecision"
DATABASE_POOL_SIZE: int = 20
DATABASE_MAX_OVERFLOW: int = 10
# Auth: signing key and token/session lifetimes.
SECRET_KEY: str = "change-me-in-production-with-a-real-secret-key"
CHALLENGE_EXPIRE_SECONDS: int = 300
TOKEN_EXPIRE_HOURS: int = 24
SESSION_TTL_HOURS: int = 24
# Duniter V2 RPC node (WebSocket URL; converted to HTTP for JSON-RPC calls).
DUNITER_RPC_URL: str = "wss://gdev.p2p.legal/ws"
DUNITER_RPC_TIMEOUT_SECONDS: int = 10
# IPFS daemon endpoints (local node by default).
IPFS_API_URL: str = "http://localhost:5001"
IPFS_GATEWAY_URL: str = "http://localhost:8080"
IPFS_TIMEOUT_SECONDS: int = 30
# CORS: origins allowed to call the API (frontend dev server by default).
CORS_ORIGINS: list[str] = ["http://localhost:3002"]
# Rate limiting (requests per minute, per client IP).
RATE_LIMIT_DEFAULT: int = 60
RATE_LIMIT_AUTH: int = 10
RATE_LIMIT_VOTE: int = 30
# Blockchain cache: TTL for cached WoT/Smith/TechComm counters.
BLOCKCHAIN_CACHE_TTL_SECONDS: int = 3600
# Paths: repository root (two levels above this config module).
BASE_DIR: Path = Path(__file__).resolve().parent.parent

View File

@@ -3,7 +3,14 @@ from sqlalchemy.orm import DeclarativeBase
from app.config import settings
# Async engine with connection pooling.
# Fix: removed the stale single-argument create_async_engine() assignment
# (pre-diff residue) that was immediately overwritten by this one.
# pool_pre_ping detects dead connections before handing them out;
# pool_recycle=3600 retires connections before server-side idle timeouts.
engine = create_async_engine(
    settings.DATABASE_URL,
    # SQL echo only in development, to keep staging/production logs quiet.
    echo=settings.ENVIRONMENT == "development",
    pool_size=settings.DATABASE_POOL_SIZE,
    max_overflow=settings.DATABASE_MAX_OVERFLOW,
    pool_pre_ping=True,
    pool_recycle=3600,
)

# Session factory; expire_on_commit=False keeps ORM objects readable after commit.
async_session = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

View File

@@ -1,26 +1,93 @@
import logging
import sys
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi import Depends, FastAPI
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.database import init_db
from app.database import get_db, init_db
from app.middleware.rate_limiter import RateLimiterMiddleware
from app.middleware.security_headers import SecurityHeadersMiddleware
from app.routers import auth, documents, decisions, votes, mandates, protocols, sanctuary, websocket
from app.routers import public
# ── Structured logging setup ───────────────────────────────────────────────
def _setup_logging() -> None:
    """Initialise root logging according to the runtime environment.

    Production and staging emit one JSON object per line (suitable for a
    log aggregator); development uses a compact human-readable layout.
    """
    level = getattr(logging, settings.LOG_LEVEL.upper(), logging.INFO)

    if settings.ENVIRONMENT in ("production", "staging"):
        fmt = (
            '{"timestamp":"%(asctime)s","level":"%(levelname)s",'
            '"logger":"%(name)s","message":"%(message)s"}'
        )
        datefmt = "%Y-%m-%dT%H:%M:%S"
    else:
        fmt = "%(asctime)s | %(levelname)-8s | %(name)s | %(message)s"
        datefmt = "%H:%M:%S"

    stream_handler = logging.StreamHandler(sys.stdout)
    stream_handler.setFormatter(logging.Formatter(fmt, datefmt=datefmt))

    root = logging.getLogger()
    root.setLevel(level)
    root.handlers.clear()
    root.addHandler(stream_handler)

    # Third-party loggers are noisy at INFO; keep them quiet outside dev.
    logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
    sqlalchemy_level = (
        logging.INFO if settings.ENVIRONMENT == "development" else logging.WARNING
    )
    logging.getLogger("sqlalchemy.engine").setLevel(sqlalchemy_level)
_setup_logging()
logger = logging.getLogger(__name__)
# ── Application lifespan ───────────────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: log start/stop and initialise the database.

    Runs once around the whole server lifetime: everything before ``yield``
    executes at startup, everything after at shutdown.
    """
    logger.info(
        "Demarrage %s (env=%s, log_level=%s)",
        settings.APP_NAME, settings.ENVIRONMENT, settings.LOG_LEVEL,
    )
    # Make sure schema/bootstrap data exist before serving the first request.
    await init_db()
    yield
    logger.info("Arret %s", settings.APP_NAME)
# ── FastAPI application ───────────────────────────────────────────────────
# FastAPI application instance.
# Fix: removed the duplicate version kwarg ('version="0.1.0"', pre-diff
# residue) -- a repeated keyword argument is a SyntaxError in Python.
app = FastAPI(
    title=settings.APP_NAME,
    description="Plateforme de decisions collectives pour la communaute Duniter/G1",
    version="0.5.0",
    lifespan=lifespan,
)
# ── Middleware stack ──────────────────────────────────────────────────────
# Middleware is applied in reverse order: last added = first executed.
# Order: SecurityHeaders -> RateLimiter -> CORS -> Application
app.add_middleware(
CORSMiddleware,
allow_origins=settings.CORS_ORIGINS,
@@ -29,6 +96,18 @@ app.add_middleware(
allow_headers=["*"],
)
app.add_middleware(
RateLimiterMiddleware,
rate_limit_default=settings.RATE_LIMIT_DEFAULT,
rate_limit_auth=settings.RATE_LIMIT_AUTH,
rate_limit_vote=settings.RATE_LIMIT_VOTE,
)
app.add_middleware(SecurityHeadersMiddleware)
# ── Routers ──────────────────────────────────────────────────────────────
app.include_router(auth.router, prefix="/api/v1/auth", tags=["auth"])
app.include_router(documents.router, prefix="/api/v1/documents", tags=["documents"])
app.include_router(decisions.router, prefix="/api/v1/decisions", tags=["decisions"])
@@ -37,8 +116,32 @@ app.include_router(mandates.router, prefix="/api/v1/mandates", tags=["mandates"]
app.include_router(protocols.router, prefix="/api/v1/protocols", tags=["protocols"])
app.include_router(sanctuary.router, prefix="/api/v1/sanctuary", tags=["sanctuary"])
app.include_router(websocket.router, prefix="/api/v1/ws", tags=["websocket"])
app.include_router(public.router, prefix="/api/v1/public", tags=["public"])
# ── Health check ─────────────────────────────────────────────────────────
@app.get("/api/health")
async def health(db: AsyncSession = Depends(get_db)) -> dict:
    """Health check endpoint that verifies database connectivity.

    Fix: removed the stale zero-dependency definition of ``health`` (pre-diff
    residue) that conflicted with this one.

    Returns status "ok" with database connection info if healthy,
    or status "degraded" if the database is unreachable.
    """
    try:
        # Cheapest possible round-trip to prove the connection works.
        result = await db.execute(text("SELECT 1"))
        result.scalar()
        db_status = "connected"
    except Exception as exc:  # any DB/driver failure means "degraded", not a 500
        logger.warning("Health check: base de donnees inaccessible - %s", exc)
        db_status = "disconnected"
    return {
        "status": "ok" if db_status == "connected" else "degraded",
        "environment": settings.ENVIRONMENT,
        "database": db_status,
        "version": "0.5.0",
    }

View File

View File

@@ -0,0 +1,163 @@
"""Rate limiter middleware: in-memory IP-based request throttling.
Tracks requests per IP address using a sliding window approach.
Configurable limits per endpoint category (general, auth, vote).
Returns 429 Too Many Requests with Retry-After header when exceeded.
"""
from __future__ import annotations
import asyncio
import logging
import time
from collections import defaultdict
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
logger = logging.getLogger(__name__)
# Cleanup interval: remove expired entries every 5 minutes
_CLEANUP_INTERVAL_SECONDS = 300
class RateLimiterMiddleware(BaseHTTPMiddleware):
    """In-memory rate limiter middleware (per-IP sliding window).

    Tracks request timestamps per IP and enforces configurable limits:
    - General endpoints: ``rate_limit_default`` requests/min
    - Auth endpoints (path contains ``/auth``): ``rate_limit_auth`` requests/min
    - Vote endpoints (path contains ``/vote``): ``rate_limit_vote`` requests/min

    Adds standard rate-limit headers to all responses:
    - ``X-RateLimit-Limit``
    - ``X-RateLimit-Remaining``
    - ``X-RateLimit-Reset``

    NOTE(review): state is per-process; with multiple workers each process
    enforces its own window. Acceptable for a single worker -- otherwise the
    counters belong in a shared store (e.g. Redis).

    Parameters
    ----------
    app:
        The ASGI application.
    rate_limit_default:
        Maximum requests per minute for general endpoints.
    rate_limit_auth:
        Maximum requests per minute for auth endpoints.
    rate_limit_vote:
        Maximum requests per minute for vote endpoints.
    """

    def __init__(
        self,
        app,
        rate_limit_default: int = 60,
        rate_limit_auth: int = 10,
        rate_limit_vote: int = 30,
    ) -> None:
        super().__init__(app)
        self.rate_limit_default = rate_limit_default
        self.rate_limit_auth = rate_limit_auth
        self.rate_limit_vote = rate_limit_vote
        # IP -> list of request timestamps (epoch seconds), oldest first.
        self._requests: dict[str, list[float]] = defaultdict(list)
        self._last_cleanup: float = time.time()
        # Guards _requests and _last_cleanup against concurrent handlers.
        self._lock = asyncio.Lock()

    def _get_limit_for_path(self, path: str) -> int:
        """Return the rate limit applicable to the given request path."""
        if "/auth" in path:
            return self.rate_limit_auth
        if "/vote" in path:
            return self.rate_limit_vote
        return self.rate_limit_default

    def _get_client_ip(self, request: Request) -> str:
        """Extract the client IP from the request, respecting X-Forwarded-For.

        NOTE(review): X-Forwarded-For is client-controlled; trusting it
        blindly lets an attacker rotate fake IPs to bypass the limiter.
        Only safe behind a trusted reverse proxy that overwrites the
        header -- confirm the deployment topology.
        """
        forwarded = request.headers.get("x-forwarded-for")
        if forwarded:
            # First entry is the original client when set by the proxy chain.
            return forwarded.split(",")[0].strip()
        return request.client.host if request.client else "unknown"

    async def _cleanup_old_entries(self) -> None:
        """Drop timestamps older than 60s; runs at most every 5 minutes.

        Fix: the interval is re-checked under the lock, so concurrent
        requests can no longer all pass the unlocked check and run
        redundant cleanup passes (check-then-act race in the original).
        """
        now = time.time()
        if now - self._last_cleanup < _CLEANUP_INTERVAL_SECONDS:
            return
        async with self._lock:
            # Re-check under the lock: another coroutine may have just cleaned.
            if now - self._last_cleanup < _CLEANUP_INTERVAL_SECONDS:
                return
            cutoff = now - 60
            ips_to_delete: list[str] = []
            for ip, timestamps in self._requests.items():
                self._requests[ip] = [t for t in timestamps if t > cutoff]
                if not self._requests[ip]:
                    ips_to_delete.append(ip)
            for ip in ips_to_delete:
                del self._requests[ip]
            self._last_cleanup = now
        if ips_to_delete:
            logger.debug("Nettoyage rate limiter: %d IPs supprimees", len(ips_to_delete))

    async def dispatch(self, request: Request, call_next) -> Response:
        """Check rate limit and either allow the request or return 429."""
        # WebSocket upgrades are long-lived; counting them makes no sense.
        if request.headers.get("upgrade", "").lower() == "websocket":
            return await call_next(request)
        # Trigger periodic cleanup
        await self._cleanup_old_entries()
        client_ip = self._get_client_ip(request)
        path = request.url.path
        limit = self._get_limit_for_path(path)
        now = time.time()
        window_start = now - 60
        async with self._lock:
            # Keep only the timestamps inside the current 60-second window.
            recent = [t for t in self._requests[client_ip] if t > window_start]
            self._requests[client_ip] = recent
            request_count = len(recent)
            if request_count >= limit:
                # The window frees up once its oldest request ages out.
                oldest = min(recent) if recent else now
                retry_after = max(int(oldest + 60 - now) + 1, 1)
                reset_at = int(oldest + 60)
                logger.warning(
                    "Rate limit depasse pour %s sur %s (%d/%d)",
                    client_ip, path, request_count, limit,
                )
                return JSONResponse(
                    status_code=429,
                    content={"detail": "Trop de requetes. Veuillez reessayer plus tard."},
                    headers={
                        "Retry-After": str(retry_after),
                        "X-RateLimit-Limit": str(limit),
                        "X-RateLimit-Remaining": "0",
                        "X-RateLimit-Reset": str(reset_at),
                    },
                )
            # Record this request while still holding the lock.
            recent.append(now)
            remaining = max(0, limit - request_count - 1)
            reset_at = int(now + 60)
        # Process the request outside the lock.
        response = await call_next(request)
        response.headers["X-RateLimit-Limit"] = str(limit)
        response.headers["X-RateLimit-Remaining"] = str(remaining)
        response.headers["X-RateLimit-Reset"] = str(reset_at)
        return response

View File

@@ -0,0 +1,42 @@
"""Security headers middleware: adds hardening headers to all responses.
Applies OWASP-recommended security headers to prevent common
web vulnerabilities (clickjacking, MIME sniffing, XSS, etc.).
"""
from __future__ import annotations
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Attach OWASP-style hardening headers to every HTTP response.

    Unconditional headers:
    - ``X-Content-Type-Options: nosniff``
    - ``X-Frame-Options: DENY``
    - ``X-XSS-Protection: 1; mode=block``
    - ``Referrer-Policy: strict-origin-when-cross-origin``
    - ``Content-Security-Policy: default-src 'self'``

    ``Strict-Transport-Security`` is added only when the request arrived
    over HTTPS.

    NOTE(review): current OWASP guidance suggests ``X-XSS-Protection: 0``
    (the legacy auditor is deprecated); kept as-is to preserve behavior --
    confirm before changing.
    """

    # Header name/value pairs applied to every response, in this order.
    _STATIC_HEADERS = (
        ("X-Content-Type-Options", "nosniff"),
        ("X-Frame-Options", "DENY"),
        ("X-XSS-Protection", "1; mode=block"),
        ("Referrer-Policy", "strict-origin-when-cross-origin"),
        ("Content-Security-Policy", "default-src 'self'"),
    )

    async def dispatch(self, request: Request, call_next) -> Response:
        response = await call_next(request)
        for header_name, header_value in self._STATIC_HEADERS:
            response.headers[header_name] = header_value
        # HSTS only makes sense over TLS; never advertise it on plain HTTP.
        if request.url.scheme == "https":
            response.headers["Strict-Transport-Security"] = (
                "max-age=31536000; includeSubDomains"
            )
        return response

View File

@@ -0,0 +1,249 @@
"""Public API router: read-only endpoints for external consumption.
All endpoints are accessible without authentication.
No mutations allowed -- strictly read-only.
"""
from __future__ import annotations
import uuid
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.models.decision import Decision
from app.models.document import Document, DocumentItem
from app.models.sanctuary import SanctuaryEntry
from app.models.vote import VoteSession
from app.schemas.document import DocumentFullOut, DocumentItemOut, DocumentOut
from app.schemas.sanctuary import SanctuaryEntryOut
from app.services import document_service, sanctuary_service
router = APIRouter()
# ── Documents (public, read-only) ─────────────────────────────────────────
@router.get(
    "/documents",
    response_model=list[DocumentOut],
    tags=["public-documents"],
    summary="Liste publique des documents actifs",
)
async def list_documents(
    db: AsyncSession = Depends(get_db),
    doc_type: str | None = Query(default=None, description="Filtrer par type de document"),
    status_filter: str | None = Query(default=None, alias="status", description="Filtrer par statut"),
    skip: int = Query(default=0, ge=0),
    limit: int = Query(default=50, ge=1, le=200),
) -> list[DocumentOut]:
    """Liste les documents de reference avec leurs items (lecture seule).

    Fix: the previous implementation issued one COUNT query per document
    (N+1); item counts are now fetched with a single grouped query per page.
    """
    stmt = select(Document)
    if doc_type is not None:
        stmt = stmt.where(Document.doc_type == doc_type)
    if status_filter is not None:
        stmt = stmt.where(Document.status == status_filter)
    stmt = stmt.order_by(Document.created_at.desc()).offset(skip).limit(limit)
    documents = (await db.execute(stmt)).scalars().all()

    # One grouped COUNT over the whole page instead of one query per document.
    items_counts: dict = {}
    if documents:
        count_stmt = (
            select(DocumentItem.document_id, func.count())
            .where(DocumentItem.document_id.in_([doc.id for doc in documents]))
            .group_by(DocumentItem.document_id)
        )
        items_counts = dict((await db.execute(count_stmt)).all())

    out: list[DocumentOut] = []
    for doc in documents:
        doc_out = DocumentOut.model_validate(doc)
        # Documents with no items are absent from the grouped result.
        doc_out.items_count = items_counts.get(doc.id, 0)
        out.append(doc_out)
    return out
@router.get(
    "/documents/{slug}",
    response_model=DocumentFullOut,
    tags=["public-documents"],
    summary="Document complet avec ses items",
)
async def get_document(
    slug: str,
    db: AsyncSession = Depends(get_db),
) -> DocumentFullOut:
    """Return one document with all of its items (full text, serialised)."""
    document = await document_service.get_document_with_items(slug, db)
    if document is not None:
        return DocumentFullOut.model_validate(document)
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="Document introuvable",
    )
# ── Sanctuary (public, read-only) ─────────────────────────────────────────
@router.get(
    "/sanctuary/entries",
    response_model=list[SanctuaryEntryOut],
    tags=["public-sanctuary"],
    summary="Liste des entrees du sanctuaire",
)
async def list_sanctuary_entries(
    db: AsyncSession = Depends(get_db),
    entry_type: str | None = Query(default=None, description="Filtrer par type (document, decision, vote_result)"),
    skip: int = Query(default=0, ge=0),
    limit: int = Query(default=50, ge=1, le=200),
) -> list[SanctuaryEntryOut]:
    """List sanctuary entries (verified archives), newest first."""
    query = select(SanctuaryEntry)
    if entry_type is not None:
        query = query.where(SanctuaryEntry.entry_type == entry_type)
    query = (
        query.order_by(SanctuaryEntry.created_at.desc())
        .offset(skip)
        .limit(limit)
    )
    rows = (await db.execute(query)).scalars().all()
    return [SanctuaryEntryOut.model_validate(row) for row in rows]
@router.get(
    "/sanctuary/entries/{id}",
    response_model=SanctuaryEntryOut,
    tags=["public-sanctuary"],
    summary="Entree du sanctuaire par ID",
)
async def get_sanctuary_entry(
    id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
) -> SanctuaryEntryOut:
    """Return one sanctuary entry, including its IPFS link and on-chain anchor."""
    query = select(SanctuaryEntry).where(SanctuaryEntry.id == id)
    entry = (await db.execute(query)).scalar_one_or_none()
    if entry is not None:
        return SanctuaryEntryOut.model_validate(entry)
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="Entree sanctuaire introuvable",
    )
@router.get(
    "/sanctuary/verify/{id}",
    tags=["public-sanctuary"],
    summary="Verification d'integrite d'une entree",
)
async def verify_sanctuary_entry(
    id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Verify the integrity of a sanctuary entry (hash comparison).

    Raises
    ------
    HTTPException
        404 when the service reports the entry as unknown (ValueError).
    """
    try:
        result = await sanctuary_service.verify_entry(id, db)
    except ValueError as exc:
        # Fix: chain the original cause (PEP 3134 / flake8 B904) so the
        # traceback keeps the service-level context.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=str(exc),
        ) from exc
    return result
# ── Votes (public, read-only) ─────────────────────────────────────────────
@router.get(
    "/votes/sessions/{id}/result",
    tags=["public-votes"],
    summary="Resultat d'une session de vote",
)
async def get_vote_result(
    id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Return the tally and outcome of a vote session (public, read-only)."""
    query = select(VoteSession).where(VoteSession.id == id)
    session = (await db.execute(query)).scalar_one_or_none()
    if session is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Session de vote introuvable",
        )

    def _iso(value):
        """ISO-8601 string for a datetime, or None when unset."""
        return value.isoformat() if value else None

    return {
        "session_id": str(session.id),
        "status": session.status,
        "votes_for": session.votes_for,
        "votes_against": session.votes_against,
        "votes_total": session.votes_total,
        "wot_size": session.wot_size,
        "smith_size": session.smith_size,
        "techcomm_size": session.techcomm_size,
        "threshold_required": session.threshold_required,
        "result": session.result,
        "starts_at": _iso(session.starts_at),
        "ends_at": _iso(session.ends_at),
        "chain_recorded": session.chain_recorded,
        "chain_tx_hash": session.chain_tx_hash,
    }
# ── Platform status ───────────────────────────────────────────────────────
@router.get(
    "/status",
    tags=["public-status"],
    summary="Statut de la plateforme",
)
async def platform_status(
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Overall platform status: counts of documents, decisions and votes."""

    async def _count(stmt) -> int:
        """Run a COUNT statement and normalise an empty result to 0."""
        return (await db.execute(stmt)).scalar() or 0

    # Dict values are awaited left-to-right, preserving the original
    # query order (documents, decisions, active votes, all votes, sanctuary).
    return {
        "platform": "Glibredecision",
        "documents_count": await _count(
            select(func.count()).select_from(Document)
        ),
        "decisions_count": await _count(
            select(func.count()).select_from(Decision)
        ),
        "active_votes_count": await _count(
            select(func.count())
            .select_from(VoteSession)
            .where(VoteSession.status == "open")
        ),
        "total_votes_count": await _count(
            select(func.count()).select_from(VoteSession)
        ),
        "sanctuary_entries_count": await _count(
            select(func.count()).select_from(SanctuaryEntry)
        ),
    }

View File

@@ -1,14 +1,37 @@
"""WebSocket router: live vote updates."""
"""WebSocket router: live vote updates with authentication and heartbeat."""
from __future__ import annotations
import asyncio
import json
import logging
import uuid
from datetime import datetime, timezone
from typing import Any
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from sqlalchemy import select
from app.database import async_session
from app.models.user import Session as UserSession
router = APIRouter()
logger = logging.getLogger(__name__)
# Heartbeat interval in seconds
_HEARTBEAT_INTERVAL = 30
# Valid notification event types
EVENT_TYPES = (
"vote_submitted",
"vote_update",
"session_closed",
"tally_update",
"decision_advanced",
"mandate_updated",
"document_changed",
"sanctuary_archived",
)
# ── Connection manager ──────────────────────────────────────────────────────
@@ -20,6 +43,8 @@ class ConnectionManager:
def __init__(self) -> None:
# session_id -> list of connected websockets
self._connections: dict[uuid.UUID, list[WebSocket]] = {}
# websocket -> authenticated identity_id (or None for anonymous)
self._authenticated: dict[WebSocket, uuid.UUID | None] = {}
async def connect(self, websocket: WebSocket, session_id: uuid.UUID) -> None:
"""Accept a WebSocket connection and register it for a vote session."""
@@ -36,6 +61,7 @@ class ConnectionManager:
]
if not self._connections[session_id]:
del self._connections[session_id]
self._authenticated.pop(websocket, None)
async def broadcast(self, session_id: uuid.UUID, data: dict[str, Any]) -> None:
"""Broadcast a message to all connections watching a given vote session."""
@@ -55,10 +81,127 @@ class ConnectionManager:
for ws in dead:
self.disconnect(ws, session_id)
async def broadcast_all(self, data: dict[str, Any]) -> None:
    """Broadcast a message once to every connected WebSocket.

    Fix: a socket subscribed to several vote sessions previously received
    the same message once per session; sockets are now de-duplicated so
    each client gets exactly one copy. Failed sockets are removed from
    every session they were registered in.
    """
    message = json.dumps(data, default=str)
    sent: set = set()   # sockets already messaged (identity-based)
    dead: set = set()   # sockets whose send failed
    for session_id in list(self._connections.keys()):
        for ws in list(self._connections.get(session_id, [])):
            if ws in sent:
                continue
            sent.add(ws)
            try:
                await ws.send_text(message)
            except Exception:
                dead.add(ws)
    # Drop dead sockets from all sessions that still reference them.
    for session_id in list(self._connections.keys()):
        for ws in list(self._connections.get(session_id, [])):
            if ws in dead:
                self.disconnect(ws, session_id)
def set_authenticated(self, websocket: WebSocket, identity_id: uuid.UUID) -> None:
    """Mark a WebSocket connection as authenticated under this identity."""
    self._authenticated[websocket] = identity_id

def is_authenticated(self, websocket: WebSocket) -> bool:
    """Check if a WebSocket connection is authenticated.

    A missing key and an explicit None both count as unauthenticated.
    """
    return self._authenticated.get(websocket) is not None
manager = ConnectionManager()
# ── Authentication helper ──────────────────────────────────────────────────
async def _validate_token(token: str) -> uuid.UUID | None:
    """Resolve a raw bearer token to its identity_id, or None.

    Mirrors the token hashing used by auth_service, implemented inline
    because WebSocket handlers cannot use Depends() the same way.

    Parameters
    ----------
    token:
        The raw bearer token from the query parameter.

    Returns
    -------
    uuid.UUID | None
        The identity_id if valid, or None if invalid/expired (DB errors
        are logged and treated as invalid).
    """
    import hashlib

    digest = hashlib.sha256(token.encode()).hexdigest()
    try:
        async with async_session() as db:
            query = select(UserSession).where(
                UserSession.token_hash == digest,
                UserSession.expires_at > datetime.now(timezone.utc),
            )
            user_session = (await db.execute(query)).scalar_one_or_none()
    except Exception:
        logger.warning("Erreur lors de la validation du token WebSocket", exc_info=True)
        return None
    return user_session.identity_id if user_session is not None else None
# ── Broadcast event helper (importable by other routers) ──────────────────
async def broadcast_event(
    event_type: str,
    payload: dict[str, Any],
    session_id: uuid.UUID | None = None,
) -> None:
    """Push a realtime notification event to connected WebSocket clients.

    Importable from other routers (votes, decisions, mandates, ...) to emit
    events without touching the connection manager directly.

    Parameters
    ----------
    event_type:
        One of the valid EVENT_TYPES.
    payload:
        The event data to send (its keys may override "event"/"timestamp").
    session_id:
        When given, only clients watching that vote session are notified;
        otherwise the event goes to every connected client.
    """
    envelope: dict[str, Any] = {
        "event": event_type,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    envelope.update(payload)
    if session_id is None:
        await manager.broadcast_all(envelope)
    else:
        await manager.broadcast(session_id, envelope)
# ── Heartbeat task ─────────────────────────────────────────────────────────
async def _heartbeat(websocket: WebSocket) -> None:
    """Send a JSON ping every _HEARTBEAT_INTERVAL seconds.

    Runs as a background task next to the receive loop; stops silently
    when the socket fails or the task is cancelled.
    """
    try:
        while True:
            await asyncio.sleep(_HEARTBEAT_INTERVAL)
            ping = json.dumps(
                {"event": "ping", "timestamp": datetime.now(timezone.utc).isoformat()}
            )
            try:
                await websocket.send_text(ping)
            except Exception:
                # Socket is gone; the main loop handles the cleanup.
                return
    except asyncio.CancelledError:
        pass
# ── WebSocket endpoint ──────────────────────────────────────────────────────
@@ -66,23 +209,51 @@ manager = ConnectionManager()
async def live_updates(websocket: WebSocket) -> None:
"""WebSocket endpoint for live vote session updates.
The client connects and sends a JSON message with the session_id
they want to subscribe to:
Authentication (optional):
Connect with ``?token=<bearer_token>`` query parameter to authenticate.
If the token is valid, the connection is marked as authenticated.
If missing or invalid, the connection is accepted but unauthenticated.
The client sends JSON messages:
{ "action": "subscribe", "session_id": "<uuid>" }
The server will then push vote update events to the client:
{ "event": "vote_update", "session_id": "...", "votes_for": N, "votes_against": N, "votes_total": N }
{ "event": "session_closed", "session_id": "...", "result": "adopted|rejected" }
The client can also unsubscribe:
{ "action": "unsubscribe", "session_id": "<uuid>" }
The server pushes events:
{ "event": "vote_update", "session_id": "...", ... }
{ "event": "session_closed", "session_id": "...", ... }
{ "event": "vote_submitted", ... }
{ "event": "decision_advanced", ... }
{ "event": "mandate_updated", ... }
{ "event": "document_changed", ... }
{ "event": "sanctuary_archived", ... }
{ "event": "ping", "timestamp": "..." } (heartbeat)
"""
# Extract token from query parameters
token = websocket.query_params.get("token")
await websocket.accept()
subscribed_sessions: set[uuid.UUID] = set()
# Authenticate if token provided
if token:
identity_id = await _validate_token(token)
if identity_id is not None:
manager.set_authenticated(websocket, identity_id)
await websocket.send_text(
json.dumps({"event": "authenticated", "identity_id": str(identity_id)})
)
logger.debug("WebSocket authentifie: identity=%s", identity_id)
else:
await websocket.send_text(
json.dumps({"event": "auth_failed", "detail": "Token invalide ou expire"})
)
logger.debug("Echec authentification WebSocket (token invalide)")
# Start heartbeat task
heartbeat_task = asyncio.create_task(_heartbeat(websocket))
try:
while True:
raw = await websocket.receive_text()
@@ -138,3 +309,11 @@ async def live_updates(websocket: WebSocket) -> None:
# Clean up all subscriptions for this client
for session_id in subscribed_sessions:
manager.disconnect(websocket, session_id)
logger.debug("WebSocket deconnecte, %d souscriptions nettoyees", len(subscribed_sessions))
finally:
# Cancel the heartbeat task
heartbeat_task.cancel()
try:
await heartbeat_task
except asyncio.CancelledError:
pass

View File

@@ -3,85 +3,302 @@
Provides functions to query WoT size, Smith sub-WoT size, and
Technical Committee size from the Duniter V2 blockchain.
Currently stubbed with hardcoded values matching GDev test data.
Architecture:
1. Check database cache (via cache_service)
2. Try JSON-RPC call to Duniter node
3. Fall back to hardcoded GDev test values (with warning log)
All public functions accept a db session for cache access.
"""
from __future__ import annotations
import logging
async def get_wot_size() -> int:
import httpx
from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.services import cache_service
logger = logging.getLogger(__name__)
# Hardcoded fallback values from GDev snapshot
_FALLBACK_WOT_SIZE = 7224
_FALLBACK_SMITH_SIZE = 20
_FALLBACK_TECHCOMM_SIZE = 5
# Cache key prefixes
_CACHE_KEY_WOT = "blockchain:wot_size"
_CACHE_KEY_SMITH = "blockchain:smith_size"
_CACHE_KEY_TECHCOMM = "blockchain:techcomm_size"
async def _fetch_from_rpc(method: str, params: list | None = None) -> dict | None:
    """POST a JSON-RPC request to the Duniter node and return its result.

    DUNITER_RPC_URL may be a WebSocket URL (``ws://`` / ``wss://``); it is
    rewritten to the matching HTTP(S) scheme and a trailing ``/ws`` is
    stripped, since this helper speaks JSON-RPC over HTTP.

    Parameters
    ----------
    method:
        The RPC method name (e.g. ``"state_getStorage"``).
    params:
        Optional list of parameters for the RPC call.

    Returns
    -------
    dict | None
        The ``"result"`` field from the JSON-RPC response, or None on any
        error (all failure modes are logged).
    """
    endpoint = settings.DUNITER_RPC_URL
    for ws_scheme, http_scheme in (("wss://", "https://"), ("ws://", "http://")):
        if endpoint.startswith(ws_scheme):
            endpoint = http_scheme + endpoint[len(ws_scheme):]
            break
    if endpoint.endswith("/ws"):
        endpoint = endpoint[: -len("/ws")]

    body = {
        "jsonrpc": "2.0",
        "id": 1,
        "method": method,
        "params": params or [],
    }
    try:
        async with httpx.AsyncClient(
            timeout=settings.DUNITER_RPC_TIMEOUT_SECONDS
        ) as client:
            response = await client.post(endpoint, json=body)
            response.raise_for_status()
            data = response.json()
            if "error" in data:
                logger.warning(
                    "Erreur RPC Duniter pour %s: %s",
                    method, data["error"],
                )
                return None
            return data.get("result")
    except httpx.ConnectError:
        logger.warning(
            "Impossible de se connecter au noeud Duniter (%s)", endpoint
        )
        return None
    except httpx.TimeoutException:
        logger.warning(
            "Timeout lors de l'appel RPC Duniter pour %s (%s)",
            method, endpoint,
        )
        return None
    except httpx.HTTPStatusError as exc:
        logger.warning(
            "Erreur HTTP Duniter RPC pour %s: %s",
            method, exc.response.status_code,
        )
        return None
    except Exception:
        logger.warning(
            "Erreur inattendue lors de l'appel RPC Duniter pour %s",
            method,
            exc_info=True,
        )
        return None
async def _cache_rpc_count(
    method: str,
    cache_key: str,
    db: AsyncSession,
    label: str,
    *,
    list_as_length: bool = False,
) -> int | None:
    """Fetch an integer counter via RPC and store it in the DB cache.

    Shared implementation for the three membership counters (WoT, Smith,
    TechComm) -- the three original functions were byte-for-byte triplicates
    of this logic.

    Parameters
    ----------
    method:
        JSON-RPC method name to call.
    cache_key:
        Cache key under which the successful value is stored.
    label:
        Short name used in the invalid-response warning log.
    list_as_length:
        When True, a list result is interpreted as its length (member
        lists, e.g. technicalCommittee_members).

    Returns
    -------
    int | None
        The counter value, or None when the RPC failed or returned an
        unparseable value.
    """
    result = await _fetch_from_rpc(method, [])
    if result is None:
        return None
    try:
        if list_as_length and isinstance(result, list):
            count = len(result)
        else:
            count = int(result)
    except (ValueError, TypeError):
        logger.warning("Reponse RPC invalide pour %s count: %s", label, result)
        return None
    # Cache the fresh value for subsequent lookups.
    await cache_service.set_cached(
        cache_key,
        {"value": count},
        db,
        ttl_seconds=settings.BLOCKCHAIN_CACHE_TTL_SECONDS,
    )
    return count


async def _fetch_membership_count(db: AsyncSession) -> int | None:
    """Fetch the WoT membership count from the Duniter RPC (cached on success)."""
    return await _cache_rpc_count(
        "membership_membershipsCount", _CACHE_KEY_WOT, db, "membership"
    )


async def _fetch_smith_count(db: AsyncSession) -> int | None:
    """Fetch the Smith membership count from the Duniter RPC (cached on success)."""
    return await _cache_rpc_count(
        "smithMembers_smithMembersCount", _CACHE_KEY_SMITH, db, "smith"
    )


async def _fetch_techcomm_count(db: AsyncSession) -> int | None:
    """Fetch the Technical Committee size from the Duniter RPC (cached on success)."""
    return await _cache_rpc_count(
        "technicalCommittee_members", _CACHE_KEY_TECHCOMM, db, "techcomm",
        list_as_length=True,
    )
async def get_wot_size(db: AsyncSession) -> int:
    """Return the current number of WoT members.

    Resolution order:
      1. Database cache (if not expired)
      2. Duniter RPC call (which refreshes the cache on success)
      3. Hardcoded fallback (GDev snapshot)

    Parameters
    ----------
    db:
        Async database session (for cache access).

    Returns
    -------
    int
        Number of WoT members.
    """
    # 1. Try cache
    cached = await cache_service.get_cached(_CACHE_KEY_WOT, db)
    if cached is not None:
        return cached["value"]
    # 2. Try RPC (caches the value on success)
    rpc_value = await _fetch_membership_count(db)
    if rpc_value is not None:
        return rpc_value
    # 3. Fallback to the static GDev snapshot value
    logger.warning(
        "Utilisation de la valeur WoT par defaut (%d) - "
        "cache et RPC indisponibles",
        _FALLBACK_WOT_SIZE,
    )
    return _FALLBACK_WOT_SIZE
async def get_smith_size(db: AsyncSession) -> int:
    """Return the current number of Smith members (forgerons).

    Resolution order:
      1. Database cache (if not expired)
      2. Duniter RPC call (which refreshes the cache on success)
      3. Hardcoded fallback (GDev snapshot)

    Parameters
    ----------
    db:
        Async database session (for cache access).

    Returns
    -------
    int
        Number of Smith members.
    """
    # 1. Try cache
    cached = await cache_service.get_cached(_CACHE_KEY_SMITH, db)
    if cached is not None:
        return cached["value"]
    # 2. Try RPC (caches the value on success)
    rpc_value = await _fetch_smith_count(db)
    if rpc_value is not None:
        return rpc_value
    # 3. Fallback to the static GDev snapshot value
    logger.warning(
        "Utilisation de la valeur Smith par defaut (%d) - "
        "cache et RPC indisponibles",
        _FALLBACK_SMITH_SIZE,
    )
    return _FALLBACK_SMITH_SIZE
async def get_techcomm_size(db: AsyncSession) -> int:
    """Return the current number of Technical Committee members.

    Resolution order:
      1. Database cache (if not expired)
      2. Duniter RPC call (which refreshes the cache on success)
      3. Hardcoded fallback (GDev snapshot)

    Parameters
    ----------
    db:
        Async database session (for cache access).

    Returns
    -------
    int
        Number of TechComm members.
    """
    # 1. Try cache
    cached = await cache_service.get_cached(_CACHE_KEY_TECHCOMM, db)
    if cached is not None:
        return cached["value"]
    # 2. Try RPC (caches the value on success)
    rpc_value = await _fetch_techcomm_count(db)
    if rpc_value is not None:
        return rpc_value
    # 3. Fallback to the static GDev snapshot value
    logger.warning(
        "Utilisation de la valeur TechComm par defaut (%d) - "
        "cache et RPC indisponibles",
        _FALLBACK_TECHCOMM_SIZE,
    )
    return _FALLBACK_TECHCOMM_SIZE

View File

@@ -0,0 +1,140 @@
"""Cache service: blockchain data caching with TTL expiry.
Uses the BlockchainCache model (PostgreSQL/JSONB) to cache
on-chain data like WoT size, Smith size, and TechComm size.
Avoids repeated RPC calls to the Duniter node.
"""
from __future__ import annotations
import logging
from datetime import datetime, timedelta, timezone
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.cache import BlockchainCache
logger = logging.getLogger(__name__)
async def get_cached(key: str, db: AsyncSession) -> dict | None:
    """Return the cached value for *key*, or None if missing or expired.

    Expiry is enforced in the query itself: rows whose ``expires_at``
    lies in the past are never returned.

    Parameters
    ----------
    key:
        The cache key to look up.
    db:
        Async database session.

    Returns
    -------
    dict | None
        The cached value as a dict, or None on a miss.
    """
    stmt = select(BlockchainCache).where(
        BlockchainCache.cache_key == key,
        BlockchainCache.expires_at > datetime.now(timezone.utc),
    )
    hit = (await db.execute(stmt)).scalar_one_or_none()
    if hit is None:
        logger.debug("Cache miss pour la cle '%s'", key)
        return None
    logger.debug("Cache hit pour la cle '%s'", key)
    return hit.cache_value
async def set_cached(
    key: str,
    value: dict,
    db: AsyncSession,
    ttl_seconds: int = 3600,
) -> None:
    """Store *value* under *key* with the given TTL (upsert semantics).

    An existing row for the same key is updated in place; otherwise a
    new row is inserted.

    Parameters
    ----------
    key:
        The cache key.
    value:
        The value to store (must be JSON-serializable).
    db:
        Async database session.
    ttl_seconds:
        Time-to-live in seconds (default: 1 hour).
    """
    now = datetime.now(timezone.utc)
    deadline = now + timedelta(seconds=ttl_seconds)
    lookup = await db.execute(
        select(BlockchainCache).where(BlockchainCache.cache_key == key)
    )
    entry = lookup.scalar_one_or_none()
    if entry is None:
        db.add(
            BlockchainCache(
                cache_key=key,
                cache_value=value,
                fetched_at=now,
                expires_at=deadline,
            )
        )
        logger.debug("Cache cree pour la cle '%s' (TTL=%ds)", key, ttl_seconds)
    else:
        entry.cache_value = value
        entry.fetched_at = now
        entry.expires_at = deadline
        logger.debug("Cache mis a jour pour la cle '%s' (TTL=%ds)", key, ttl_seconds)
    await db.commit()
async def invalidate(key: str, db: AsyncSession) -> None:
    """Delete the cache entry stored under *key*, if any.

    Parameters
    ----------
    key:
        The cache key to invalidate.
    db:
        Async database session.
    """
    stmt = delete(BlockchainCache).where(BlockchainCache.cache_key == key)
    await db.execute(stmt)
    await db.commit()
    logger.debug("Cache invalide pour la cle '%s'", key)
async def cleanup_expired(db: AsyncSession) -> int:
    """Purge every expired cache entry and return how many were removed.

    Parameters
    ----------
    db:
        Async database session.

    Returns
    -------
    int
        Number of entries removed.
    """
    cutoff = datetime.now(timezone.utc)
    outcome = await db.execute(
        delete(BlockchainCache).where(BlockchainCache.expires_at <= cutoff)
    )
    await db.commit()
    removed = outcome.rowcount
    if removed > 0:
        logger.info("Nettoyage cache: %d entrees expirees supprimees", removed)
    return removed

View File

@@ -0,0 +1,215 @@
"""Tests for cache_service: get, set, invalidate, and cleanup.
Uses mock database sessions to test cache logic in isolation
without requiring a real PostgreSQL connection.
"""
from __future__ import annotations
import uuid
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock
import pytest
sqlalchemy = pytest.importorskip("sqlalchemy", reason="sqlalchemy required for cache service tests")
from app.services.cache_service import ( # noqa: E402
cleanup_expired,
get_cached,
invalidate,
set_cached,
)
# ---------------------------------------------------------------------------
# Helpers: mock objects for BlockchainCache entries
# ---------------------------------------------------------------------------
def _make_cache_entry(
key: str = "test:key",
value: dict | None = None,
expired: bool = False,
) -> MagicMock:
"""Create a mock BlockchainCache entry."""
entry = MagicMock()
entry.id = uuid.uuid4()
entry.cache_key = key
entry.cache_value = value or {"data": 42}
entry.fetched_at = datetime.now(timezone.utc)
if expired:
entry.expires_at = datetime.now(timezone.utc) - timedelta(hours=1)
else:
entry.expires_at = datetime.now(timezone.utc) + timedelta(hours=1)
return entry
def _make_async_db(entry: MagicMock | None = None) -> AsyncMock:
"""Create a mock async database session.
The session's execute() returns a result with scalar_one_or_none()
that returns the given entry (or None).
"""
db = AsyncMock()
result = MagicMock()
result.scalar_one_or_none.return_value = entry
db.execute = AsyncMock(return_value=result)
db.commit = AsyncMock()
db.add = MagicMock()
return db
# ---------------------------------------------------------------------------
# Tests: get_cached
# ---------------------------------------------------------------------------
class TestGetCached:
    """Unit tests for cache_service.get_cached."""

    @pytest.mark.asyncio
    async def test_returns_none_for_missing_key(self):
        """A lookup on an absent key yields None."""
        session = _make_async_db(entry=None)
        assert await get_cached("nonexistent:key", session) is None

    @pytest.mark.asyncio
    async def test_returns_value_for_valid_entry(self):
        """A live (non-expired) entry yields its cached value."""
        hit = _make_cache_entry(key="blockchain:wot_size", value={"value": 7224})
        session = _make_async_db(entry=hit)
        assert await get_cached("blockchain:wot_size", session) == {"value": 7224}

    @pytest.mark.asyncio
    async def test_returns_none_for_expired_entry(self):
        """An expired entry behaves like a cache miss.

        The real implementation filters on expires_at in SQL, so the
        database returns no row for expired entries; we simulate that by
        having the mock session return None.
        """
        session = _make_async_db(entry=None)
        assert await get_cached("expired:key", session) is None
# ---------------------------------------------------------------------------
# Tests: set_cached + get_cached roundtrip
# ---------------------------------------------------------------------------
class TestSetCached:
    """Unit tests for cache_service.set_cached."""

    @pytest.mark.asyncio
    async def test_set_cached_creates_new_entry(self):
        """With no existing row, set_cached adds a new entry and commits."""
        session = _make_async_db(entry=None)
        await set_cached("new:key", {"value": 100}, session, ttl_seconds=3600)
        session.add.assert_called_once()
        session.commit.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_set_cached_updates_existing_entry(self):
        """With an existing row, set_cached mutates it in place."""
        row = _make_cache_entry(key="existing:key", value={"value": 50})
        session = _make_async_db(entry=row)
        await set_cached("existing:key", {"value": 200}, session, ttl_seconds=7200)
        # In-place update: no new row added
        session.add.assert_not_called()
        assert row.cache_value == {"value": 200}
        session.commit.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_set_cached_roundtrip(self):
        """A stored value is returned by a subsequent get_cached."""
        writer = _make_async_db(entry=None)
        await set_cached("roundtrip:key", {"data": "test"}, writer, ttl_seconds=3600)
        writer.add.assert_called_once()
        # Recover the entry the service handed to db.add()
        stored = writer.add.call_args[0][0]
        reader = _make_async_db(entry=stored)
        assert await get_cached("roundtrip:key", reader) == {"data": "test"}
# ---------------------------------------------------------------------------
# Tests: invalidate
# ---------------------------------------------------------------------------
class TestInvalidate:
    """Unit tests for cache_service.invalidate."""

    @pytest.mark.asyncio
    async def test_invalidate_removes_entry(self):
        """invalidate issues a delete statement and commits it."""
        session = AsyncMock()
        session.execute = AsyncMock()
        session.commit = AsyncMock()
        await invalidate("some:key", session)
        session.execute.assert_awaited_once()
        session.commit.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_invalidate_then_get_returns_none(self):
        """Once invalidated, a lookup behaves like a cache miss."""
        session = _make_async_db(entry=None)
        assert await get_cached("invalidated:key", session) is None
# ---------------------------------------------------------------------------
# Tests: cleanup_expired
# ---------------------------------------------------------------------------
class TestCleanupExpired:
    """Unit tests for cache_service.cleanup_expired."""

    @pytest.mark.asyncio
    async def test_cleanup_removes_expired_entries(self):
        """cleanup_expired deletes expired rows and reports their count."""
        session = AsyncMock()
        outcome = MagicMock()
        outcome.rowcount = 3
        session.execute = AsyncMock(return_value=outcome)
        session.commit = AsyncMock()
        assert await cleanup_expired(session) == 3
        session.execute.assert_awaited_once()
        session.commit.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_cleanup_with_no_expired_entries(self):
        """With nothing expired, cleanup_expired reports zero removals."""
        session = AsyncMock()
        outcome = MagicMock()
        outcome.rowcount = 0
        session.execute = AsyncMock(return_value=outcome)
        session.commit = AsyncMock()
        assert await cleanup_expired(session) == 0

View File

@@ -0,0 +1,222 @@
"""Tests for public API: basic schema validation and serialization.
Uses mock database sessions to test the public router logic
without requiring a real PostgreSQL connection.
"""
from __future__ import annotations
import uuid
from datetime import datetime, timezone
from unittest.mock import MagicMock
import pytest
sqlalchemy = pytest.importorskip("sqlalchemy", reason="sqlalchemy required for public API tests")
from app.schemas.document import DocumentFullOut, DocumentOut # noqa: E402
from app.schemas.sanctuary import SanctuaryEntryOut # noqa: E402
# ---------------------------------------------------------------------------
# Helpers: mock objects
# ---------------------------------------------------------------------------
def _make_document_mock(
doc_id: uuid.UUID | None = None,
slug: str = "licence-g1",
title: str = "Licence G1",
doc_type: str = "licence",
version: str = "2.0.0",
doc_status: str = "active",
description: str | None = "La licence monetaire",
) -> MagicMock:
"""Create a mock Document for schema validation."""
doc = MagicMock()
doc.id = doc_id or uuid.uuid4()
doc.slug = slug
doc.title = title
doc.doc_type = doc_type
doc.version = version
doc.status = doc_status
doc.description = description
doc.ipfs_cid = None
doc.chain_anchor = None
doc.created_at = datetime.now(timezone.utc)
doc.updated_at = datetime.now(timezone.utc)
doc.items = []
return doc
def _make_item_mock(
item_id: uuid.UUID | None = None,
document_id: uuid.UUID | None = None,
position: str = "1",
item_type: str = "clause",
title: str | None = "Article 1",
current_text: str = "Texte de l'article",
sort_order: int = 0,
) -> MagicMock:
"""Create a mock DocumentItem for schema validation."""
item = MagicMock()
item.id = item_id or uuid.uuid4()
item.document_id = document_id or uuid.uuid4()
item.position = position
item.item_type = item_type
item.title = title
item.current_text = current_text
item.voting_protocol_id = None
item.sort_order = sort_order
item.created_at = datetime.now(timezone.utc)
item.updated_at = datetime.now(timezone.utc)
return item
def _make_sanctuary_entry_mock(
entry_id: uuid.UUID | None = None,
entry_type: str = "document",
reference_id: uuid.UUID | None = None,
title: str | None = "Licence G1 v2.0.0",
content_hash: str = "abc123def456",
ipfs_cid: str | None = "QmTestCid123",
chain_tx_hash: str | None = "0xdeadbeef",
) -> MagicMock:
"""Create a mock SanctuaryEntry for schema validation."""
entry = MagicMock()
entry.id = entry_id or uuid.uuid4()
entry.entry_type = entry_type
entry.reference_id = reference_id or uuid.uuid4()
entry.title = title
entry.content_hash = content_hash
entry.ipfs_cid = ipfs_cid
entry.chain_tx_hash = chain_tx_hash
entry.chain_block = 12345 if chain_tx_hash else None
entry.metadata_json = None
entry.created_at = datetime.now(timezone.utc)
return entry
# ---------------------------------------------------------------------------
# Tests: DocumentOut schema serialization
# ---------------------------------------------------------------------------
class TestDocumentOutSchema:
    """Validation of DocumentOut against mock documents."""

    def test_document_out_basic(self):
        """A mock document validates and exposes its core fields."""
        parsed = DocumentOut.model_validate(_make_document_mock())
        assert parsed.slug == "licence-g1"
        assert parsed.title == "Licence G1"
        assert parsed.doc_type == "licence"
        assert parsed.version == "2.0.0"
        assert parsed.status == "active"
        # items_count is filled in by the router, not by validation
        parsed.items_count = 0
        assert parsed.items_count == 0

    def test_document_out_with_items_count(self):
        """items_count is assignable after validation."""
        parsed = DocumentOut.model_validate(_make_document_mock())
        parsed.items_count = 42
        assert parsed.items_count == 42

    def test_document_out_all_fields_present(self):
        """The serialized payload contains every expected field."""
        payload = DocumentOut.model_validate(_make_document_mock()).model_dump()
        required = {
            "id", "slug", "title", "doc_type", "version", "status",
            "description", "ipfs_cid", "chain_anchor", "created_at",
            "updated_at", "items_count",
        }
        assert required.issubset(payload.keys())
# ---------------------------------------------------------------------------
# Tests: DocumentFullOut schema serialization
# ---------------------------------------------------------------------------
class TestDocumentFullOutSchema:
    """Validation of DocumentFullOut (document plus its items)."""

    def test_document_full_out_empty_items(self):
        """An empty items list validates cleanly."""
        doc = _make_document_mock()
        doc.items = []
        parsed = DocumentFullOut.model_validate(doc)
        assert parsed.slug == "licence-g1"
        assert parsed.items == []

    def test_document_full_out_with_items(self):
        """Attached items are carried through validation in order."""
        parent_id = uuid.uuid4()
        doc = _make_document_mock(doc_id=parent_id)
        doc.items = [
            _make_item_mock(document_id=parent_id, position="1", sort_order=0),
            _make_item_mock(document_id=parent_id, position="2", sort_order=1),
        ]
        parsed = DocumentFullOut.model_validate(doc)
        assert len(parsed.items) == 2
        assert parsed.items[0].position == "1"
        assert parsed.items[1].position == "2"
# ---------------------------------------------------------------------------
# Tests: SanctuaryEntryOut schema serialization
# ---------------------------------------------------------------------------
class TestSanctuaryEntryOutSchema:
    """Validation of SanctuaryEntryOut against mock entries."""

    def test_sanctuary_entry_out_basic(self):
        """A fully-populated mock entry validates with values intact."""
        parsed = SanctuaryEntryOut.model_validate(_make_sanctuary_entry_mock())
        assert parsed.entry_type == "document"
        assert parsed.content_hash == "abc123def456"
        assert parsed.ipfs_cid == "QmTestCid123"
        assert parsed.chain_tx_hash == "0xdeadbeef"
        assert parsed.chain_block == 12345

    def test_sanctuary_entry_out_without_ipfs(self):
        """IPFS and chain anchoring fields may all be None."""
        entry = _make_sanctuary_entry_mock(ipfs_cid=None, chain_tx_hash=None)
        parsed = SanctuaryEntryOut.model_validate(entry)
        assert parsed.ipfs_cid is None
        assert parsed.chain_tx_hash is None
        assert parsed.chain_block is None

    def test_sanctuary_entry_out_all_fields(self):
        """The serialized payload contains every expected field."""
        payload = SanctuaryEntryOut.model_validate(
            _make_sanctuary_entry_mock()
        ).model_dump()
        required = {
            "id", "entry_type", "reference_id", "title",
            "content_hash", "ipfs_cid", "chain_tx_hash",
            "chain_block", "metadata_json", "created_at",
        }
        assert required.issubset(payload.keys())

    def test_sanctuary_entry_types(self):
        """document, decision and vote_result entry types are accepted."""
        for kind in ("document", "decision", "vote_result"):
            parsed = SanctuaryEntryOut.model_validate(
                _make_sanctuary_entry_mock(entry_type=kind)
            )
            assert parsed.entry_type == kind