"""Rate limiter middleware: in-memory IP-based request throttling. Tracks requests per IP address using a sliding window approach. Configurable limits per endpoint category (general, auth, vote). Returns 429 Too Many Requests with Retry-After header when exceeded. """ from __future__ import annotations import asyncio import logging import time from collections import defaultdict from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from starlette.responses import JSONResponse, Response logger = logging.getLogger(__name__) # Cleanup interval: remove expired entries every 5 minutes _CLEANUP_INTERVAL_SECONDS = 300 class RateLimiterMiddleware(BaseHTTPMiddleware): """In-memory rate limiter middleware. Tracks request timestamps per IP and enforces configurable limits: - General endpoints: ``rate_limit_default`` requests/min - Auth endpoints (``/auth``): ``rate_limit_auth`` requests/min - Vote endpoints (``/vote``): ``rate_limit_vote`` requests/min Adds standard rate-limit headers to all responses: - ``X-RateLimit-Limit`` - ``X-RateLimit-Remaining`` - ``X-RateLimit-Reset`` Parameters ---------- app: The ASGI application. rate_limit_default: Maximum requests per minute for general endpoints. rate_limit_auth: Maximum requests per minute for auth endpoints. rate_limit_vote: Maximum requests per minute for vote endpoints. """ def __init__( self, app, rate_limit_default: int = 60, rate_limit_auth: int = 10, rate_limit_vote: int = 30, ) -> None: super().__init__(app) self.rate_limit_default = rate_limit_default self.rate_limit_auth = rate_limit_auth self.rate_limit_vote = rate_limit_vote # IP -> list of timestamps (epoch seconds) self._requests: dict[str, list[float]] = defaultdict(list) self._last_cleanup: float = time.time() self._lock = asyncio.Lock() def _get_limit_for_path(self, path: str) -> int: """Return the rate limit applicable to the given request path.""" if "/auth" in path: return self.rate_limit_auth if "/vote" in path: return self.rate_limit_vote return self.rate_limit_default def _get_client_ip(self, request: Request) -> str: """Extract the client IP from the request, respecting X-Forwarded-For.""" forwarded = request.headers.get("x-forwarded-for") if forwarded: return forwarded.split(",")[0].strip() return request.client.host if request.client else "unknown" async def _cleanup_old_entries(self) -> None: """Remove request timestamps older than 60 seconds for all IPs.""" now = time.time() if now - self._last_cleanup < _CLEANUP_INTERVAL_SECONDS: return async with self._lock: cutoff = now - 60 ips_to_delete: list[str] = [] for ip, timestamps in self._requests.items(): self._requests[ip] = [t for t in timestamps if t > cutoff] if not self._requests[ip]: ips_to_delete.append(ip) for ip in ips_to_delete: del self._requests[ip] self._last_cleanup = now if ips_to_delete: logger.debug("Nettoyage rate limiter: %d IPs supprimees", len(ips_to_delete)) async def dispatch(self, request: Request, call_next) -> Response: """Check rate limit and either allow the request or return 429.""" # Skip rate limiting for WebSocket upgrades if request.headers.get("upgrade", "").lower() == "websocket": return await call_next(request) # Trigger periodic cleanup await self._cleanup_old_entries() client_ip = self._get_client_ip(request) path = request.url.path limit = self._get_limit_for_path(path) now = time.time() window_start = now - 60 async with self._lock: # Filter to requests within the last 60 seconds self._requests[client_ip] = [ t for t in self._requests[client_ip] if t > window_start ] request_count = len(self._requests[client_ip]) if request_count >= limit: # Calculate when the oldest request in the window expires oldest = min(self._requests[client_ip]) if self._requests[client_ip] else now retry_after = int(oldest + 60 - now) + 1 retry_after = max(retry_after, 1) reset_at = int(oldest + 60) logger.warning( "Rate limit depasse pour %s sur %s (%d/%d)", client_ip, path, request_count, limit, ) return JSONResponse( status_code=429, content={"detail": "Trop de requetes. Veuillez reessayer plus tard."}, headers={ "Retry-After": str(retry_after), "X-RateLimit-Limit": str(limit), "X-RateLimit-Remaining": "0", "X-RateLimit-Reset": str(reset_at), }, ) # Record this request self._requests[client_ip].append(now) remaining = max(0, limit - request_count - 1) reset_at = int(now + 60) # Process the request response = await call_next(request) # Add rate limit headers to the response response.headers["X-RateLimit-Limit"] = str(limit) response.headers["X-RateLimit-Remaining"] = str(remaining) response.headers["X-RateLimit-Reset"] = str(reset_at) return response