//! Tiny per-IP token-bucket rate limiter. //! //! Currently used by `POST /v1/messages` to keep a single source from //! filling everyone's mailbox until disk fills (the v0.1 spam concern //! called out in messages.rs). Default rate: 60 messages/min per IP. //! //! Why not pull in `tower_governor` or `governor`? They're great //! crates but each adds 10+ transitive deps for what's structurally //! ~50 lines of code. We're already shipping nostr-sdk's dep tree; //! restraint here keeps the build snappy. //! //! Client IP resolution priority: //! 1. `CF-Connecting-IP` header — Cloudflare puts the real client //! IP here; we trust it because Cloudflare strips this header //! from anything that wasn't routed through our tunnel. //! 2. `X-Forwarded-For` (first hop) — fallback for non-Cloudflare //! deployments. //! 3. None — direct curl / loopback. Rate-limit by `0.0.0.0` so //! noisy test traffic still gets bucketed instead of bypassing. use std::collections::HashMap; use std::net::IpAddr; use std::sync::Arc; use std::time::{Duration, Instant}; use axum::http::HeaderMap; use tokio::sync::Mutex; /// One bucket per client. We store the residual token count + the /// last refill timestamp; on each `try_acquire` we compute how many /// tokens to add based on elapsed time, then either decrement or fail. #[derive(Debug, Clone)] struct Bucket { /// Tokens currently available. tokens: f64, /// Last time we refilled. last_refill: Instant, /// Most recent activity — used by the eviction sweep to drop /// long-cold buckets so the HashMap doesn't grow forever. last_seen: Instant, } #[derive(Debug, Clone, Copy)] pub struct RateLimitConfig { /// Bucket capacity (max burst). Once exhausted, callers fail /// fast until enough time passes to refill ≥1 token. pub capacity: u32, /// Refill rate in tokens per second. pub refill_per_sec: f64, /// Buckets idle longer than this get evicted on next sweep so /// short-lived clients don't pile up in memory. pub idle_ttl: Duration, } impl Default for RateLimitConfig { fn default() -> Self { Self { capacity: 60, refill_per_sec: 1.0, // = 60/min steady-state idle_ttl: Duration::from_secs(15 * 60), } } } /// Process-shared rate limiter. Cheap to clone (Arc inside). #[derive(Clone)] pub struct RateLimiter { inner: Arc>>, config: RateLimitConfig, } impl RateLimiter { pub fn new(config: RateLimitConfig) -> Self { Self { inner: Arc::new(Mutex::new(HashMap::new())), config, } } /// Drain one token for `ip` if available. Returns `true` on /// success (caller may proceed) or `false` if rate-limited /// (caller should respond 429). pub async fn try_acquire(&self, ip: IpAddr) -> bool { let mut map = self.inner.lock().await; let now = Instant::now(); let bucket = map.entry(ip).or_insert(Bucket { tokens: self.config.capacity as f64, last_refill: now, last_seen: now, }); // Refill since last touch. let elapsed = now.saturating_duration_since(bucket.last_refill).as_secs_f64(); bucket.tokens = (bucket.tokens + elapsed * self.config.refill_per_sec) .min(self.config.capacity as f64); bucket.last_refill = now; bucket.last_seen = now; if bucket.tokens >= 1.0 { bucket.tokens -= 1.0; true } else { false } } /// Periodically called by a background sweep to drop buckets /// for clients we haven't heard from in `idle_ttl`. Returns the /// number of buckets removed (diagnostic). pub async fn sweep(&self) -> usize { let now = Instant::now(); let mut map = self.inner.lock().await; let before = map.len(); map.retain(|_, b| now.saturating_duration_since(b.last_seen) < self.config.idle_ttl); before - map.len() } } /// Resolve the client IP from the request headers, with the /// Cloudflare-first priority documented above. Falls back to /// `0.0.0.0` if we can't extract anything sensible — that way /// direct curl traffic still gets rate-limited as a single /// "anonymous" client instead of bypassing entirely. pub fn client_ip_from_headers(headers: &HeaderMap) -> IpAddr { if let Some(v) = headers.get("CF-Connecting-IP").and_then(|h| h.to_str().ok()) { if let Ok(ip) = v.trim().parse::() { return ip; } } if let Some(v) = headers.get("X-Forwarded-For").and_then(|h| h.to_str().ok()) { // X-Forwarded-For is a comma-separated list; the leftmost // value is the original client. if let Some(first) = v.split(',').next() { if let Ok(ip) = first.trim().parse::() { return ip; } } } "0.0.0.0".parse().expect("0.0.0.0 is a valid IpAddr") } #[cfg(test)] mod tests { use super::*; fn cfg_for_test() -> RateLimitConfig { RateLimitConfig { capacity: 3, refill_per_sec: 10.0, idle_ttl: Duration::from_secs(1), } } #[tokio::test] async fn within_capacity_succeeds() { let rl = RateLimiter::new(cfg_for_test()); let ip: IpAddr = "1.2.3.4".parse().unwrap(); for _ in 0..3 { assert!(rl.try_acquire(ip).await); } } #[tokio::test] async fn exhausting_capacity_fails_then_recovers() { let rl = RateLimiter::new(cfg_for_test()); let ip: IpAddr = "1.2.3.4".parse().unwrap(); for _ in 0..3 { assert!(rl.try_acquire(ip).await); } assert!(!rl.try_acquire(ip).await, "4th request should be rate-limited"); // Refill rate is 10 tokens/sec → 1 token in 100ms. tokio::time::sleep(Duration::from_millis(150)).await; assert!(rl.try_acquire(ip).await); } #[tokio::test] async fn separate_ips_have_separate_buckets() { let rl = RateLimiter::new(cfg_for_test()); let a: IpAddr = "1.2.3.4".parse().unwrap(); let b: IpAddr = "5.6.7.8".parse().unwrap(); for _ in 0..3 { assert!(rl.try_acquire(a).await); } assert!(rl.try_acquire(b).await, "different IP should still have full bucket"); } #[test] fn cf_header_wins_over_xff() { let mut h = HeaderMap::new(); h.insert("CF-Connecting-IP", "9.9.9.9".parse().unwrap()); h.insert("X-Forwarded-For", "8.8.8.8, 7.7.7.7".parse().unwrap()); assert_eq!(client_ip_from_headers(&h).to_string(), "9.9.9.9"); } #[test] fn xff_first_hop() { let mut h = HeaderMap::new(); h.insert("X-Forwarded-For", "8.8.8.8, 7.7.7.7".parse().unwrap()); assert_eq!(client_ip_from_headers(&h).to_string(), "8.8.8.8"); } #[test] fn fallback_when_no_headers() { let h = HeaderMap::new(); assert_eq!(client_ip_from_headers(&h).to_string(), "0.0.0.0"); } }