groupchat/server/src/main.rs
Jason Tudisco 55c17b2999 fix: support hash-only permalink format with server-side resolution
Add /api/messages/hash/:hash endpoint that resolves a message hash to
its room ID (with membership check). The client now handles both
#roomId/hash and #hash formats - the latter calls the API to find
which room the message belongs to, then loads it and scrolls.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-16 17:44:44 -06:00

316 lines
12 KiB
Rust

mod handlers;
mod middleware;
mod models;
mod services;
use axum::{
routing::{get, post, put},
Router,
};
use sqlx::sqlite::SqlitePoolOptions;
use std::sync::Arc;
use tokio::sync::broadcast;
use tower_http::cors::{Any, CorsLayer};
use tower_http::services::{ServeDir, ServeFile};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
/// Extract the file path from a SQLite `DATABASE_URL`.
///
/// Accepts both forms understood by sqlx:
/// - `sqlite:chat.db?mode=rwc` -> `chat.db`
/// - `sqlite://data/chat.db` -> `data/chat.db`
/// - `sqlite:///var/lib/chat.db` -> `/var/lib/chat.db`
///
/// The `sqlite://` prefix is stripped before `sqlite:` (as sqlx does), otherwise
/// `sqlite://x.db` would be misread as the absolute path `//x.db`. Any query string
/// (e.g. `?mode=rwc`) is discarded.
///
/// Returns `None` if the URL is not a SQLite URL or names no path.
fn db_file_path(database_url: &str) -> Option<String> {
    let rest = database_url
        .strip_prefix("sqlite://")
        .or_else(|| database_url.strip_prefix("sqlite:"))?;
    // Strip query params like ?mode=rwc
    let path = rest.split_once('?').map_or(rest, |(path, _query)| path);
    if path.is_empty() {
        return None;
    }
    Some(path.to_string())
}
/// Create a timestamped backup of the SQLite database file.
///
/// Backups are stored in a `backups/` directory next to the db file and named
/// `<stem>_<YYYY-MM-DD_HHMMSS>.<ext>` (e.g. `chat.db` -> `chat_2026-03-09_143022.db`).
/// If `-wal` / `-shm` sidecar files exist they are copied with the same timestamp
/// (`chat_2026-03-09_143022.db-wal`), so the set can be restored together.
/// Only the 10 most recent backups are kept (see `prune_old_backups`).
///
/// Best-effort: every failure is logged and the function returns normally, so a
/// failed backup never prevents startup. It uses blocking `std::fs` calls and is
/// meant to run before the database pool is opened.
fn backup_database(database_url: &str) {
    let Some(db_path) = db_file_path(database_url) else {
        tracing::warn!("Could not parse database path from URL, skipping backup");
        return;
    };
    let db_file = std::path::Path::new(&db_path);
    if !db_file.exists() {
        tracing::info!("Database file does not exist yet, skipping backup");
        return;
    }

    // Create backups directory next to the database
    let backup_dir = db_file
        .parent()
        .unwrap_or(std::path::Path::new("."))
        .join("backups");
    if let Err(e) = std::fs::create_dir_all(&backup_dir) {
        tracing::error!("Failed to create backup directory: {}", e);
        return;
    }

    let stem = db_file
        .file_stem()
        .and_then(|s| s.to_str())
        .unwrap_or("db");
    let ext = db_file
        .extension()
        .and_then(|s| s.to_str())
        .unwrap_or("db");
    // Format the timestamp once so the main file and its sidecars share it exactly.
    let stamp = chrono::Local::now().format("%Y-%m-%d_%H%M%S").to_string();
    let backup_path = backup_dir.join(format!("{}_{}.{}", stem, stamp, ext));

    match std::fs::copy(db_file, &backup_path) {
        Ok(bytes) => {
            tracing::info!(
                "Database backup created: {} ({:.1} KB)",
                backup_path.display(),
                bytes as f64 / 1024.0
            );
        }
        Err(e) => {
            tracing::error!("Failed to backup database: {}", e);
            return;
        }
    }

    // The WAL can contain committed pages not yet checkpointed into the main file,
    // so a missing WAL copy means an incomplete backup: report it instead of ignoring it.
    for suffix in ["-wal", "-shm"] {
        let source = format!("{}{}", db_path, suffix);
        if !std::path::Path::new(&source).exists() {
            continue;
        }
        let target = backup_dir.join(format!("{}_{}.{}{}", stem, stamp, ext, suffix));
        if let Err(e) = std::fs::copy(&source, &target) {
            tracing::warn!("Failed to backup {}: {}", source, e);
        }
    }

    // Prune old backups: keep only the 10 most recent
    prune_old_backups(&backup_dir, stem, 10);
}
/// Remove old backups, keeping only the `keep` most recent ones.
///
/// Only files in `backup_dir` named exactly `<stem>_YYYY-MM-DD_HHMMSS.<ext>` are
/// considered, which is the main-file name `backup_database` writes. Other files are
/// never counted or deleted, including backups of a different database whose stem
/// merely starts with `<stem>_`. The timestamp is zero-padded, so lexicographic order
/// is chronological order.
///
/// When a main backup is removed, its `-wal` / `-shm` sidecars (same file name plus
/// suffix) are removed too. Best-effort: failures are logged, never propagated.
fn prune_old_backups(backup_dir: &std::path::Path, stem: &str, keep: usize) {
    let mut backups: Vec<(String, std::path::PathBuf)> = match std::fs::read_dir(backup_dir) {
        Ok(entries) => entries
            .filter_map(|entry| entry.ok())
            .filter_map(|entry| {
                // Non-UTF-8 names can never match our ASCII timestamp pattern.
                let name = entry.file_name().into_string().ok()?;
                is_main_backup_name(&name, stem).then(|| (name, entry.path()))
            })
            .collect(),
        Err(e) => {
            tracing::warn!(
                "Failed to read backup directory {}: {}",
                backup_dir.display(),
                e
            );
            return;
        }
    };
    if backups.len() <= keep {
        return;
    }

    // Oldest first (timestamps sort lexicographically); names in a directory are unique.
    backups.sort_unstable_by(|a, b| a.0.cmp(&b.0));
    let to_remove = backups.len() - keep;
    for (name, path) in backups.into_iter().take(to_remove) {
        if let Err(e) = std::fs::remove_file(&path) {
            tracing::warn!("Failed to remove old backup {}: {}", name, e);
            continue;
        }
        tracing::debug!("Pruned old backup: {}", name);
        // Also remove associated WAL/SHM backups; they are optional, so NotFound is expected.
        for suffix in ["-wal", "-shm"] {
            let sidecar = path.with_file_name(format!("{}{}", name, suffix));
            match std::fs::remove_file(&sidecar) {
                Ok(()) => {}
                Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
                Err(e) => {
                    tracing::warn!("Failed to remove old backup {}: {}", sidecar.display(), e)
                }
            }
        }
    }
}

/// True if `name` is `<stem>_YYYY-MM-DD_HHMMSS.<anything>` and is not a `-wal`/`-shm` sidecar.
fn is_main_backup_name(name: &str, stem: &str) -> bool {
    if name.ends_with("-wal") || name.ends_with("-shm") {
        return false;
    }
    let Some(rest) = name.strip_prefix(stem).and_then(|r| r.strip_prefix('_')) else {
        return false;
    };
    // "YYYY-MM-DD_HHMMSS" is 17 ASCII bytes and must be followed by the extension dot.
    let bytes = rest.as_bytes();
    if bytes.len() < 18 || bytes[17] != b'.' {
        return false;
    }
    bytes[..17].iter().enumerate().all(|(i, &b)| match i {
        4 | 7 => b == b'-',
        10 => b == b'_',
        _ => b.is_ascii_digit(),
    })
}
/// Shared application state, built once in `main`, wrapped in an `Arc`, and
/// attached to the router with `with_state` so route handlers can access it.
///
/// NOTE(review): this struct carries API secrets; avoid deriving `Debug` for it
/// or logging it.
pub struct AppState {
    /// SQLite connection pool (opened in `main` with `max_connections(5)`).
    pub db: sqlx::SqlitePool,
    /// JWT secret read from `JWT_SECRET`; `main` falls back to a fixed
    /// development value when the variable is unset.
    pub jwt_secret: String,
    /// OpenRouter API key from `OPENROUTER_API_KEY` (startup fails without it).
    pub openrouter_key: String,
    /// Brave API key from `BRAVE_API_KEY` (startup fails without it).
    pub brave_api_key: String,
    /// Sender half of the `BroadcastEvent` broadcast channel (capacity 4096).
    /// Presumably used by handlers/WebSocket code for real-time fan-out — confirm in `handlers`.
    pub tx: broadcast::Sender<models::BroadcastEvent>,
}
#[tokio::main]
async fn main() {
dotenvy::dotenv().ok();
tracing_subscriber::registry()
.with(tracing_subscriber::EnvFilter::new(
std::env::var("RUST_LOG").unwrap_or_else(|_| "info".into()),
))
.with(tracing_subscriber::fmt::layer())
.init();
let database_url = std::env::var("DATABASE_URL").unwrap_or_else(|_| "sqlite:chat.db?mode=rwc".into());
let jwt_secret = std::env::var("JWT_SECRET").unwrap_or_else(|_| "dev-secret-change-me".into());
let openrouter_key = std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY must be set");
let brave_api_key = std::env::var("BRAVE_API_KEY").expect("BRAVE_API_KEY must be set");
// Backup the database before connecting and running migrations
backup_database(&database_url);
let db = SqlitePoolOptions::new()
.max_connections(5)
.connect(&database_url)
.await
.expect("Failed to connect to database");
// Run migrations
let migration_sql = include_str!("../migrations/001_init.sql");
sqlx::raw_sql(migration_sql)
.execute(&db)
.await
.expect("Failed to run migrations");
// Run migration 002 - soft delete
let migration_002 = include_str!("../migrations/002_soft_delete.sql");
match sqlx::raw_sql(migration_002).execute(&db).await {
Ok(_) => tracing::info!("Migration 002 applied"),
Err(e) if e.to_string().contains("duplicate column") => {
tracing::debug!("Migration 002 already applied, skipping");
}
Err(e) => panic!("Failed to run migration 002: {}", e),
}
// Run migration 003 - ai_meta on messages
let migration_003 = include_str!("../migrations/003_ai_meta.sql");
match sqlx::raw_sql(migration_003).execute(&db).await {
Ok(_) => tracing::info!("Migration 003 applied"),
Err(e) if e.to_string().contains("duplicate column") => {
tracing::debug!("Migration 003 already applied, skipping");
}
Err(e) => panic!("Failed to run migration 003: {}", e),
}
// Run migration 004 - ai_name on rooms
let migration_004 = include_str!("../migrations/004_ai_name.sql");
match sqlx::raw_sql(migration_004).execute(&db).await {
Ok(_) => tracing::info!("Migration 004 applied"),
Err(e) if e.to_string().contains("duplicate column") => {
tracing::debug!("Migration 004 already applied, skipping");
}
Err(e) => panic!("Failed to run migration 004: {}", e),
}
// Run migration 005 - avatar_url on users
let migration_005 = include_str!("../migrations/005_avatar.sql");
match sqlx::raw_sql(migration_005).execute(&db).await {
Ok(_) => tracing::info!("Migration 005 applied"),
Err(e) if e.to_string().contains("duplicate column") => {
tracing::debug!("Migration 005 already applied, skipping");
}
Err(e) => panic!("Failed to run migration 005: {}", e),
}
// Run migration 006 - image_url on messages
let migration_006 = include_str!("../migrations/006_image_url.sql");
match sqlx::raw_sql(migration_006).execute(&db).await {
Ok(_) => tracing::info!("Migration 006 applied"),
Err(e) if e.to_string().contains("duplicate column") => {
tracing::debug!("Migration 006 already applied, skipping");
}
Err(e) => panic!("Failed to run migration 006: {}", e),
}
// Run migration 007 - SHA-256 integrity hash on messages
let migration_007 = include_str!("../migrations/007_message_hash.sql");
match sqlx::raw_sql(migration_007).execute(&db).await {
Ok(_) => {
tracing::info!("Migration 007 applied, backfilling hashes for existing messages...");
// Backfill hashes for all existing messages that don't have one
let rows = sqlx::query_as::<_, (String, String, String)>(
"SELECT id, created_at, content FROM messages WHERE hash IS NULL",
)
.fetch_all(&db)
.await
.unwrap_or_default();
let count = rows.len();
for (id, created_at, content) in rows {
let hash = models::message_hash(&created_at, &content);
let _ = sqlx::query("UPDATE messages SET hash = ? WHERE id = ?")
.bind(&hash)
.bind(&id)
.execute(&db)
.await;
}
if count > 0 {
tracing::info!("Backfilled hashes for {} existing messages", count);
}
}
Err(e) if e.to_string().contains("duplicate column") => {
tracing::debug!("Migration 007 already applied, skipping");
}
Err(e) => panic!("Failed to run migration 007: {}", e),
}
tracing::info!("Database initialized");
let (tx, _rx) = broadcast::channel::<models::BroadcastEvent>(4096);
let state = Arc::new(AppState {
db,
jwt_secret,
openrouter_key,
brave_api_key,
tx,
});
let cors = CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any);
// Serve static files from client dist in production
let static_dir = std::env::var("STATIC_DIR").unwrap_or_else(|_| "../client/dist".into());
let api_routes = Router::new()
// Auth routes
.route("/api/auth/register", post(handlers::auth::register))
.route("/api/auth/login", post(handlers::auth::login))
.route("/api/auth/me", get(handlers::auth::me))
// Profile routes
.route("/api/auth/profile", put(handlers::profile::update_profile))
.route("/api/auth/avatar", post(handlers::profile::upload_avatar).delete(handlers::profile::delete_avatar))
// Room routes
.route("/api/rooms", get(handlers::rooms::list_rooms).post(handlers::rooms::create_room))
.route("/api/rooms/:room_id", get(handlers::rooms::get_room).delete(handlers::rooms::delete_room))
.route("/api/rooms/:room_id/messages", get(handlers::rooms::get_messages))
.route("/api/rooms/:room_id/join", post(handlers::rooms::join_room))
.route("/api/rooms/:room_id/clear", post(handlers::rooms::clear_room))
.route("/api/messages/hash/:hash", get(handlers::rooms::resolve_message_hash))
// Upload (chat images)
.route("/api/upload", post(handlers::upload::upload_chat_image))
// Models
.route("/api/models", get(handlers::models::list_models))
// Invite routes
.route("/api/invites", post(handlers::invites::create_invite))
.route("/api/invites/:token/accept", post(handlers::invites::accept_invite))
// Uploaded files (avatars)
.nest_service("/uploads", ServeDir::new("uploads"))
// WebSocket
.route("/ws", get(handlers::ws::ws_handler))
.layer(cors)
.with_state(state);
// SPA fallback: serve static assets, fall back to index.html for client-side routing
let spa = ServeDir::new(&static_dir)
.not_found_service(ServeFile::new(format!("{}/index.html", static_dir)));
let app = api_routes.fallback_service(spa);
let addr = std::env::var("BIND_ADDR").unwrap_or_else(|_| "0.0.0.0:3001".into());
tracing::info!("Server starting on {}", addr);
let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
axum::serve(listener, app).await.unwrap();
}