mod handlers; mod middleware; mod models; mod services; use axum::{ routing::{get, post, put}, Router, }; use sqlx::sqlite::SqlitePoolOptions; use std::sync::Arc; use tokio::sync::broadcast; use tower_http::cors::{Any, CorsLayer}; use tower_http::services::{ServeDir, ServeFile}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; /// Extract the file path from a SQLite DATABASE_URL like "sqlite:chat.db?mode=rwc" fn db_file_path(database_url: &str) -> Option { let path = database_url.strip_prefix("sqlite:")?; // Strip query params like ?mode=rwc let path = path.split('?').next().unwrap_or(path); Some(path.to_string()) } /// Create a timestamped backup of the SQLite database file. /// Backups are stored in a `backups/` directory next to the db file. /// Only keeps the 10 most recent backups to avoid unbounded disk usage. fn backup_database(database_url: &str) { let Some(db_path) = db_file_path(database_url) else { tracing::warn!("Could not parse database path from URL, skipping backup"); return; }; let db_file = std::path::Path::new(&db_path); if !db_file.exists() { tracing::info!("Database file does not exist yet, skipping backup"); return; } // Create backups directory next to the database let backup_dir = db_file .parent() .unwrap_or(std::path::Path::new(".")) .join("backups"); if let Err(e) = std::fs::create_dir_all(&backup_dir) { tracing::error!("Failed to create backup directory: {}", e); return; } // Build timestamped backup filename: chat.db -> chat_2026-03-09_143022.db let stem = db_file .file_stem() .and_then(|s| s.to_str()) .unwrap_or("db"); let ext = db_file .extension() .and_then(|s| s.to_str()) .unwrap_or("db"); let now = chrono::Local::now(); let backup_name = format!("{}_{}.{}", stem, now.format("%Y-%m-%d_%H%M%S"), ext); let backup_path = backup_dir.join(&backup_name); match std::fs::copy(db_file, &backup_path) { Ok(bytes) => { tracing::info!( "Database backup created: {} ({:.1} KB)", backup_path.display(), bytes as f64 / 1024.0 ); } Err(e) => { tracing::error!("Failed to backup database: {}", e); return; } } // Also copy WAL and SHM files if they exist (for consistency) let wal_path = format!("{}-wal", db_path); let shm_path = format!("{}-shm", db_path); if std::path::Path::new(&wal_path).exists() { let wal_backup = backup_dir.join(format!("{}_{}.{}-wal", stem, now.format("%Y-%m-%d_%H%M%S"), ext)); let _ = std::fs::copy(&wal_path, &wal_backup); } if std::path::Path::new(&shm_path).exists() { let shm_backup = backup_dir.join(format!("{}_{}.{}-shm", stem, now.format("%Y-%m-%d_%H%M%S"), ext)); let _ = std::fs::copy(&shm_path, &shm_backup); } // Prune old backups: keep only the 10 most recent prune_old_backups(&backup_dir, stem, 10); } /// Remove old backups, keeping only the `keep` most recent ones. fn prune_old_backups(backup_dir: &std::path::Path, stem: &str, keep: usize) { let prefix = format!("{}_", stem); let mut backups: Vec<_> = std::fs::read_dir(backup_dir) .into_iter() .flatten() .filter_map(|e| e.ok()) .filter(|e| { let name = e.file_name(); let name = name.to_string_lossy(); // Match main db backups (not -wal/-shm) name.starts_with(&prefix) && !name.ends_with("-wal") && !name.ends_with("-shm") }) .collect(); if backups.len() <= keep { return; } // Sort by filename (timestamps sort lexicographically) backups.sort_by_key(|e| e.file_name()); let to_remove = backups.len() - keep; for entry in backups.into_iter().take(to_remove) { let path = entry.path(); let name = path.file_name().unwrap_or_default().to_string_lossy().to_string(); if let Err(e) = std::fs::remove_file(&path) { tracing::warn!("Failed to remove old backup {}: {}", name, e); } else { tracing::debug!("Pruned old backup: {}", name); // Also remove associated WAL/SHM backups let wal = path.with_extension(format!("{}-wal", path.extension().unwrap_or_default().to_string_lossy())); let shm = path.with_extension(format!("{}-shm", path.extension().unwrap_or_default().to_string_lossy())); let _ = std::fs::remove_file(&wal); let _ = std::fs::remove_file(&shm); } } } pub struct AppState { pub db: sqlx::SqlitePool, pub jwt_secret: String, pub openrouter_key: String, pub brave_api_key: String, pub tx: broadcast::Sender, } #[tokio::main] async fn main() { dotenvy::dotenv().ok(); tracing_subscriber::registry() .with(tracing_subscriber::EnvFilter::new( std::env::var("RUST_LOG").unwrap_or_else(|_| "info".into()), )) .with(tracing_subscriber::fmt::layer()) .init(); let database_url = std::env::var("DATABASE_URL").unwrap_or_else(|_| "sqlite:chat.db?mode=rwc".into()); let jwt_secret = std::env::var("JWT_SECRET").unwrap_or_else(|_| "dev-secret-change-me".into()); let openrouter_key = std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY must be set"); let brave_api_key = std::env::var("BRAVE_API_KEY").expect("BRAVE_API_KEY must be set"); // Backup the database before connecting and running migrations backup_database(&database_url); let db = SqlitePoolOptions::new() .max_connections(5) .connect(&database_url) .await .expect("Failed to connect to database"); // Run migrations let migration_sql = include_str!("../migrations/001_init.sql"); sqlx::raw_sql(migration_sql) .execute(&db) .await .expect("Failed to run migrations"); // Run migration 002 - soft delete let migration_002 = include_str!("../migrations/002_soft_delete.sql"); match sqlx::raw_sql(migration_002).execute(&db).await { Ok(_) => tracing::info!("Migration 002 applied"), Err(e) if e.to_string().contains("duplicate column") => { tracing::debug!("Migration 002 already applied, skipping"); } Err(e) => panic!("Failed to run migration 002: {}", e), } // Run migration 003 - ai_meta on messages let migration_003 = include_str!("../migrations/003_ai_meta.sql"); match sqlx::raw_sql(migration_003).execute(&db).await { Ok(_) => tracing::info!("Migration 003 applied"), Err(e) if e.to_string().contains("duplicate column") => { tracing::debug!("Migration 003 already applied, skipping"); } Err(e) => panic!("Failed to run migration 003: {}", e), } // Run migration 004 - ai_name on rooms let migration_004 = include_str!("../migrations/004_ai_name.sql"); match sqlx::raw_sql(migration_004).execute(&db).await { Ok(_) => tracing::info!("Migration 004 applied"), Err(e) if e.to_string().contains("duplicate column") => { tracing::debug!("Migration 004 already applied, skipping"); } Err(e) => panic!("Failed to run migration 004: {}", e), } // Run migration 005 - avatar_url on users let migration_005 = include_str!("../migrations/005_avatar.sql"); match sqlx::raw_sql(migration_005).execute(&db).await { Ok(_) => tracing::info!("Migration 005 applied"), Err(e) if e.to_string().contains("duplicate column") => { tracing::debug!("Migration 005 already applied, skipping"); } Err(e) => panic!("Failed to run migration 005: {}", e), } // Run migration 006 - image_url on messages let migration_006 = include_str!("../migrations/006_image_url.sql"); match sqlx::raw_sql(migration_006).execute(&db).await { Ok(_) => tracing::info!("Migration 006 applied"), Err(e) if e.to_string().contains("duplicate column") => { tracing::debug!("Migration 006 already applied, skipping"); } Err(e) => panic!("Failed to run migration 006: {}", e), } // Run migration 007 - SHA-256 integrity hash on messages let migration_007 = include_str!("../migrations/007_message_hash.sql"); match sqlx::raw_sql(migration_007).execute(&db).await { Ok(_) => { tracing::info!("Migration 007 applied, backfilling hashes for existing messages..."); // Backfill hashes for all existing messages that don't have one let rows = sqlx::query_as::<_, (String, String, String)>( "SELECT id, created_at, content FROM messages WHERE hash IS NULL", ) .fetch_all(&db) .await .unwrap_or_default(); let count = rows.len(); for (id, created_at, content) in rows { let hash = models::message_hash(&created_at, &content); let _ = sqlx::query("UPDATE messages SET hash = ? WHERE id = ?") .bind(&hash) .bind(&id) .execute(&db) .await; } if count > 0 { tracing::info!("Backfilled hashes for {} existing messages", count); } } Err(e) if e.to_string().contains("duplicate column") => { tracing::debug!("Migration 007 already applied, skipping"); } Err(e) => panic!("Failed to run migration 007: {}", e), } tracing::info!("Database initialized"); let (tx, _rx) = broadcast::channel::(4096); let state = Arc::new(AppState { db, jwt_secret, openrouter_key, brave_api_key, tx, }); let cors = CorsLayer::new() .allow_origin(Any) .allow_methods(Any) .allow_headers(Any); // Serve static files from client dist in production let static_dir = std::env::var("STATIC_DIR").unwrap_or_else(|_| "../client/dist".into()); let api_routes = Router::new() // Auth routes .route("/api/auth/register", post(handlers::auth::register)) .route("/api/auth/login", post(handlers::auth::login)) .route("/api/auth/me", get(handlers::auth::me)) // Profile routes .route("/api/auth/profile", put(handlers::profile::update_profile)) .route("/api/auth/avatar", post(handlers::profile::upload_avatar).delete(handlers::profile::delete_avatar)) // Room routes .route("/api/rooms", get(handlers::rooms::list_rooms).post(handlers::rooms::create_room)) .route("/api/rooms/:room_id", get(handlers::rooms::get_room).delete(handlers::rooms::delete_room)) .route("/api/rooms/:room_id/messages", get(handlers::rooms::get_messages)) .route("/api/rooms/:room_id/join", post(handlers::rooms::join_room)) .route("/api/rooms/:room_id/clear", post(handlers::rooms::clear_room)) .route("/api/messages/hash/:hash", get(handlers::rooms::resolve_message_hash)) // Upload (chat images) .route("/api/upload", post(handlers::upload::upload_chat_image)) // Models .route("/api/models", get(handlers::models::list_models)) // Invite routes .route("/api/invites", post(handlers::invites::create_invite)) .route("/api/invites/:token/accept", post(handlers::invites::accept_invite)) // Uploaded files (avatars) .nest_service("/uploads", ServeDir::new("uploads")) // WebSocket .route("/ws", get(handlers::ws::ws_handler)) .layer(cors) .with_state(state); // SPA fallback: serve static assets, fall back to index.html for client-side routing let spa = ServeDir::new(&static_dir) .not_found_service(ServeFile::new(format!("{}/index.html", static_dir))); let app = api_routes.fallback_service(spa); let addr = std::env::var("BIND_ADDR").unwrap_or_else(|_| "0.0.0.0:3001".into()); tracing::info!("Server starting on {}", addr); let listener = tokio::net::TcpListener::bind(&addr).await.unwrap(); axum::serve(listener, app).await.unwrap(); }