// Add /api/messages/hash/:hash endpoint that resolves a message hash to its
// room ID (with membership check). The client now handles both #roomId/hash
// and #hash formats - the latter calls the API to find which room the message
// belongs to, then loads it and scrolls.
// Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
mod handlers;
mod middleware;
mod models;
mod services;

use axum::{
    routing::{get, post, put},
    Router,
};
use sqlx::sqlite::SqlitePoolOptions;
use std::sync::Arc;
use tokio::sync::broadcast;
use tower_http::cors::{Any, CorsLayer};
use tower_http::services::{ServeDir, ServeFile};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
/// Extract the file path from a SQLite DATABASE_URL like "sqlite:chat.db?mode=rwc"
|
|
fn db_file_path(database_url: &str) -> Option<String> {
|
|
let path = database_url.strip_prefix("sqlite:")?;
|
|
// Strip query params like ?mode=rwc
|
|
let path = path.split('?').next().unwrap_or(path);
|
|
Some(path.to_string())
|
|
}
/// Create a timestamped backup of the SQLite database file.
|
|
/// Backups are stored in a `backups/` directory next to the db file.
|
|
/// Only keeps the 10 most recent backups to avoid unbounded disk usage.
|
|
fn backup_database(database_url: &str) {
|
|
let Some(db_path) = db_file_path(database_url) else {
|
|
tracing::warn!("Could not parse database path from URL, skipping backup");
|
|
return;
|
|
};
|
|
|
|
let db_file = std::path::Path::new(&db_path);
|
|
if !db_file.exists() {
|
|
tracing::info!("Database file does not exist yet, skipping backup");
|
|
return;
|
|
}
|
|
|
|
// Create backups directory next to the database
|
|
let backup_dir = db_file
|
|
.parent()
|
|
.unwrap_or(std::path::Path::new("."))
|
|
.join("backups");
|
|
|
|
if let Err(e) = std::fs::create_dir_all(&backup_dir) {
|
|
tracing::error!("Failed to create backup directory: {}", e);
|
|
return;
|
|
}
|
|
|
|
// Build timestamped backup filename: chat.db -> chat_2026-03-09_143022.db
|
|
let stem = db_file
|
|
.file_stem()
|
|
.and_then(|s| s.to_str())
|
|
.unwrap_or("db");
|
|
let ext = db_file
|
|
.extension()
|
|
.and_then(|s| s.to_str())
|
|
.unwrap_or("db");
|
|
|
|
let now = chrono::Local::now();
|
|
let backup_name = format!("{}_{}.{}", stem, now.format("%Y-%m-%d_%H%M%S"), ext);
|
|
let backup_path = backup_dir.join(&backup_name);
|
|
|
|
match std::fs::copy(db_file, &backup_path) {
|
|
Ok(bytes) => {
|
|
tracing::info!(
|
|
"Database backup created: {} ({:.1} KB)",
|
|
backup_path.display(),
|
|
bytes as f64 / 1024.0
|
|
);
|
|
}
|
|
Err(e) => {
|
|
tracing::error!("Failed to backup database: {}", e);
|
|
return;
|
|
}
|
|
}
|
|
|
|
// Also copy WAL and SHM files if they exist (for consistency)
|
|
let wal_path = format!("{}-wal", db_path);
|
|
let shm_path = format!("{}-shm", db_path);
|
|
if std::path::Path::new(&wal_path).exists() {
|
|
let wal_backup = backup_dir.join(format!("{}_{}.{}-wal", stem, now.format("%Y-%m-%d_%H%M%S"), ext));
|
|
let _ = std::fs::copy(&wal_path, &wal_backup);
|
|
}
|
|
if std::path::Path::new(&shm_path).exists() {
|
|
let shm_backup = backup_dir.join(format!("{}_{}.{}-shm", stem, now.format("%Y-%m-%d_%H%M%S"), ext));
|
|
let _ = std::fs::copy(&shm_path, &shm_backup);
|
|
}
|
|
|
|
// Prune old backups: keep only the 10 most recent
|
|
prune_old_backups(&backup_dir, stem, 10);
|
|
}
/// Remove old backups, keeping only the `keep` most recent ones.
|
|
fn prune_old_backups(backup_dir: &std::path::Path, stem: &str, keep: usize) {
|
|
let prefix = format!("{}_", stem);
|
|
let mut backups: Vec<_> = std::fs::read_dir(backup_dir)
|
|
.into_iter()
|
|
.flatten()
|
|
.filter_map(|e| e.ok())
|
|
.filter(|e| {
|
|
let name = e.file_name();
|
|
let name = name.to_string_lossy();
|
|
// Match main db backups (not -wal/-shm)
|
|
name.starts_with(&prefix) && !name.ends_with("-wal") && !name.ends_with("-shm")
|
|
})
|
|
.collect();
|
|
|
|
if backups.len() <= keep {
|
|
return;
|
|
}
|
|
|
|
// Sort by filename (timestamps sort lexicographically)
|
|
backups.sort_by_key(|e| e.file_name());
|
|
|
|
let to_remove = backups.len() - keep;
|
|
for entry in backups.into_iter().take(to_remove) {
|
|
let path = entry.path();
|
|
let name = path.file_name().unwrap_or_default().to_string_lossy().to_string();
|
|
if let Err(e) = std::fs::remove_file(&path) {
|
|
tracing::warn!("Failed to remove old backup {}: {}", name, e);
|
|
} else {
|
|
tracing::debug!("Pruned old backup: {}", name);
|
|
// Also remove associated WAL/SHM backups
|
|
let wal = path.with_extension(format!("{}-wal", path.extension().unwrap_or_default().to_string_lossy()));
|
|
let shm = path.with_extension(format!("{}-shm", path.extension().unwrap_or_default().to_string_lossy()));
|
|
let _ = std::fs::remove_file(&wal);
|
|
let _ = std::fs::remove_file(&shm);
|
|
}
|
|
}
|
|
}
pub struct AppState {
|
|
pub db: sqlx::SqlitePool,
|
|
pub jwt_secret: String,
|
|
pub openrouter_key: String,
|
|
pub brave_api_key: String,
|
|
pub tx: broadcast::Sender<models::BroadcastEvent>,
|
|
}
#[tokio::main]
|
|
async fn main() {
|
|
dotenvy::dotenv().ok();
|
|
|
|
tracing_subscriber::registry()
|
|
.with(tracing_subscriber::EnvFilter::new(
|
|
std::env::var("RUST_LOG").unwrap_or_else(|_| "info".into()),
|
|
))
|
|
.with(tracing_subscriber::fmt::layer())
|
|
.init();
|
|
|
|
let database_url = std::env::var("DATABASE_URL").unwrap_or_else(|_| "sqlite:chat.db?mode=rwc".into());
|
|
let jwt_secret = std::env::var("JWT_SECRET").unwrap_or_else(|_| "dev-secret-change-me".into());
|
|
let openrouter_key = std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY must be set");
|
|
let brave_api_key = std::env::var("BRAVE_API_KEY").expect("BRAVE_API_KEY must be set");
|
|
|
|
// Backup the database before connecting and running migrations
|
|
backup_database(&database_url);
|
|
|
|
let db = SqlitePoolOptions::new()
|
|
.max_connections(5)
|
|
.connect(&database_url)
|
|
.await
|
|
.expect("Failed to connect to database");
|
|
|
|
// Run migrations
|
|
let migration_sql = include_str!("../migrations/001_init.sql");
|
|
sqlx::raw_sql(migration_sql)
|
|
.execute(&db)
|
|
.await
|
|
.expect("Failed to run migrations");
|
|
|
|
// Run migration 002 - soft delete
|
|
let migration_002 = include_str!("../migrations/002_soft_delete.sql");
|
|
match sqlx::raw_sql(migration_002).execute(&db).await {
|
|
Ok(_) => tracing::info!("Migration 002 applied"),
|
|
Err(e) if e.to_string().contains("duplicate column") => {
|
|
tracing::debug!("Migration 002 already applied, skipping");
|
|
}
|
|
Err(e) => panic!("Failed to run migration 002: {}", e),
|
|
}
|
|
|
|
// Run migration 003 - ai_meta on messages
|
|
let migration_003 = include_str!("../migrations/003_ai_meta.sql");
|
|
match sqlx::raw_sql(migration_003).execute(&db).await {
|
|
Ok(_) => tracing::info!("Migration 003 applied"),
|
|
Err(e) if e.to_string().contains("duplicate column") => {
|
|
tracing::debug!("Migration 003 already applied, skipping");
|
|
}
|
|
Err(e) => panic!("Failed to run migration 003: {}", e),
|
|
}
|
|
|
|
// Run migration 004 - ai_name on rooms
|
|
let migration_004 = include_str!("../migrations/004_ai_name.sql");
|
|
match sqlx::raw_sql(migration_004).execute(&db).await {
|
|
Ok(_) => tracing::info!("Migration 004 applied"),
|
|
Err(e) if e.to_string().contains("duplicate column") => {
|
|
tracing::debug!("Migration 004 already applied, skipping");
|
|
}
|
|
Err(e) => panic!("Failed to run migration 004: {}", e),
|
|
}
|
|
|
|
// Run migration 005 - avatar_url on users
|
|
let migration_005 = include_str!("../migrations/005_avatar.sql");
|
|
match sqlx::raw_sql(migration_005).execute(&db).await {
|
|
Ok(_) => tracing::info!("Migration 005 applied"),
|
|
Err(e) if e.to_string().contains("duplicate column") => {
|
|
tracing::debug!("Migration 005 already applied, skipping");
|
|
}
|
|
Err(e) => panic!("Failed to run migration 005: {}", e),
|
|
}
|
|
|
|
// Run migration 006 - image_url on messages
|
|
let migration_006 = include_str!("../migrations/006_image_url.sql");
|
|
match sqlx::raw_sql(migration_006).execute(&db).await {
|
|
Ok(_) => tracing::info!("Migration 006 applied"),
|
|
Err(e) if e.to_string().contains("duplicate column") => {
|
|
tracing::debug!("Migration 006 already applied, skipping");
|
|
}
|
|
Err(e) => panic!("Failed to run migration 006: {}", e),
|
|
}
|
|
|
|
// Run migration 007 - SHA-256 integrity hash on messages
|
|
let migration_007 = include_str!("../migrations/007_message_hash.sql");
|
|
match sqlx::raw_sql(migration_007).execute(&db).await {
|
|
Ok(_) => {
|
|
tracing::info!("Migration 007 applied, backfilling hashes for existing messages...");
|
|
// Backfill hashes for all existing messages that don't have one
|
|
let rows = sqlx::query_as::<_, (String, String, String)>(
|
|
"SELECT id, created_at, content FROM messages WHERE hash IS NULL",
|
|
)
|
|
.fetch_all(&db)
|
|
.await
|
|
.unwrap_or_default();
|
|
|
|
let count = rows.len();
|
|
for (id, created_at, content) in rows {
|
|
let hash = models::message_hash(&created_at, &content);
|
|
let _ = sqlx::query("UPDATE messages SET hash = ? WHERE id = ?")
|
|
.bind(&hash)
|
|
.bind(&id)
|
|
.execute(&db)
|
|
.await;
|
|
}
|
|
if count > 0 {
|
|
tracing::info!("Backfilled hashes for {} existing messages", count);
|
|
}
|
|
}
|
|
Err(e) if e.to_string().contains("duplicate column") => {
|
|
tracing::debug!("Migration 007 already applied, skipping");
|
|
}
|
|
Err(e) => panic!("Failed to run migration 007: {}", e),
|
|
}
|
|
|
|
tracing::info!("Database initialized");
|
|
|
|
let (tx, _rx) = broadcast::channel::<models::BroadcastEvent>(4096);
|
|
|
|
let state = Arc::new(AppState {
|
|
db,
|
|
jwt_secret,
|
|
openrouter_key,
|
|
brave_api_key,
|
|
tx,
|
|
});
|
|
|
|
let cors = CorsLayer::new()
|
|
.allow_origin(Any)
|
|
.allow_methods(Any)
|
|
.allow_headers(Any);
|
|
|
|
// Serve static files from client dist in production
|
|
let static_dir = std::env::var("STATIC_DIR").unwrap_or_else(|_| "../client/dist".into());
|
|
|
|
let api_routes = Router::new()
|
|
// Auth routes
|
|
.route("/api/auth/register", post(handlers::auth::register))
|
|
.route("/api/auth/login", post(handlers::auth::login))
|
|
.route("/api/auth/me", get(handlers::auth::me))
|
|
// Profile routes
|
|
.route("/api/auth/profile", put(handlers::profile::update_profile))
|
|
.route("/api/auth/avatar", post(handlers::profile::upload_avatar).delete(handlers::profile::delete_avatar))
|
|
// Room routes
|
|
.route("/api/rooms", get(handlers::rooms::list_rooms).post(handlers::rooms::create_room))
|
|
.route("/api/rooms/:room_id", get(handlers::rooms::get_room).delete(handlers::rooms::delete_room))
|
|
.route("/api/rooms/:room_id/messages", get(handlers::rooms::get_messages))
|
|
.route("/api/rooms/:room_id/join", post(handlers::rooms::join_room))
|
|
.route("/api/rooms/:room_id/clear", post(handlers::rooms::clear_room))
|
|
.route("/api/messages/hash/:hash", get(handlers::rooms::resolve_message_hash))
|
|
// Upload (chat images)
|
|
.route("/api/upload", post(handlers::upload::upload_chat_image))
|
|
// Models
|
|
.route("/api/models", get(handlers::models::list_models))
|
|
// Invite routes
|
|
.route("/api/invites", post(handlers::invites::create_invite))
|
|
.route("/api/invites/:token/accept", post(handlers::invites::accept_invite))
|
|
// Uploaded files (avatars)
|
|
.nest_service("/uploads", ServeDir::new("uploads"))
|
|
// WebSocket
|
|
.route("/ws", get(handlers::ws::ws_handler))
|
|
.layer(cors)
|
|
.with_state(state);
|
|
|
|
// SPA fallback: serve static assets, fall back to index.html for client-side routing
|
|
let spa = ServeDir::new(&static_dir)
|
|
.not_found_service(ServeFile::new(format!("{}/index.html", static_dir)));
|
|
|
|
let app = api_routes.fallback_service(spa);
|
|
|
|
let addr = std::env::var("BIND_ADDR").unwrap_or_else(|_| "0.0.0.0:3001".into());
|
|
tracing::info!("Server starting on {}", addr);
|
|
let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
|
|
axum::serve(listener, app).await.unwrap();
|
|
}