use axum::{ extract::State, http::StatusCode, Json, }; use serde::{Deserialize, Serialize}; use std::sync::Arc; use tokio::sync::OnceCell; use std::time::{Duration, Instant}; use tokio::sync::Mutex; use crate::AppState; /// Cached model list with expiry. static MODEL_CACHE: OnceCell> = OnceCell::const_new(); struct CachedModels { models: Vec, fetched_at: Instant, } const CACHE_TTL: Duration = Duration::from_secs(60 * 30); // 30 minutes #[derive(Debug, Clone, Serialize)] pub struct ModelInfo { pub id: String, pub name: String, pub context_length: Option, pub pricing_prompt: Option, pub pricing_completion: Option, pub supports_vision: bool, } #[derive(Debug, Deserialize)] struct OpenRouterModelsResponse { data: Vec, } #[derive(Debug, Deserialize)] struct OpenRouterModel { id: String, name: String, context_length: Option, pricing: Option, architecture: Option, } #[derive(Debug, Deserialize)] struct OpenRouterPricing { prompt: Option, completion: Option, } #[derive(Debug, Deserialize)] struct OpenRouterArchitecture { input_modalities: Option>, } async fn fetch_models(api_key: &str) -> Result, String> { let client = reqwest::Client::new(); let response = client .get("https://openrouter.ai/api/v1/models") .header("Authorization", format!("Bearer {}", api_key)) .send() .await .map_err(|e| format!("Failed to fetch models: {}", e))?; if !response.status().is_success() { let status = response.status(); let body = response.text().await.unwrap_or_default(); return Err(format!("OpenRouter models API error {}: {}", status, body)); } let data: OpenRouterModelsResponse = response .json() .await .map_err(|e| format!("Failed to parse models response: {}", e))?; let models: Vec = data .data .into_iter() .map(|m| { let pricing = m.pricing.as_ref(); let supports_vision = m.architecture .as_ref() .and_then(|a| a.input_modalities.as_ref()) .map(|mods| mods.iter().any(|m| m == "image")) .unwrap_or(false); ModelInfo { id: m.id, name: m.name, context_length: m.context_length, pricing_prompt: pricing.and_then(|p| p.prompt.clone()), pricing_completion: pricing.and_then(|p| p.completion.clone()), supports_vision, } }) .collect(); tracing::info!("Fetched {} models from OpenRouter", models.len()); Ok(models) } pub async fn list_models( State(state): State>, ) -> Result>, (StatusCode, String)> { let cache = MODEL_CACHE .get_or_init(|| async { Mutex::new(CachedModels { models: Vec::new(), fetched_at: Instant::now() - CACHE_TTL - Duration::from_secs(1), // expired }) }) .await; let mut cached = cache.lock().await; if cached.models.is_empty() || cached.fetched_at.elapsed() > CACHE_TTL { match fetch_models(&state.openrouter_key).await { Ok(models) => { cached.models = models; cached.fetched_at = Instant::now(); } Err(e) => { // If we have stale data, return it rather than erroring if !cached.models.is_empty() { tracing::warn!("Failed to refresh models, returning stale cache: {}", e); return Ok(Json(cached.models.clone())); } return Err((StatusCode::BAD_GATEWAY, e)); } } } Ok(Json(cached.models.clone())) }