Document server code paths

This commit is contained in:
Jason Tudisco 2026-03-17 15:14:04 -06:00
parent 927d106eae
commit c37ff79514
19 changed files with 480 additions and 195 deletions

View File

@ -1,8 +1,8 @@
use axum::{extract::State, http::StatusCode, Json};
use argon2::{ use argon2::{
password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString}, password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
Argon2, Argon2,
}; };
use axum::{extract::State, http::StatusCode, Json};
use std::sync::Arc; use std::sync::Arc;
use uuid::Uuid; use uuid::Uuid;
@ -12,6 +12,7 @@ use crate::{
AppState, AppState,
}; };
/// Create a new password-based account and immediately return a JWT.
pub async fn register( pub async fn register(
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,
Json(body): Json<RegisterRequest>, Json(body): Json<RegisterRequest>,
@ -69,6 +70,7 @@ pub async fn register(
})) }))
} }
/// Authenticate an existing password-based account and return a fresh JWT.
pub async fn login( pub async fn login(
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,
Json(body): Json<LoginRequest>, Json(body): Json<LoginRequest>,
@ -87,8 +89,8 @@ pub async fn login(
let (user_id, email, display_name, hash, avatar_url) = user; let (user_id, email, display_name, hash, avatar_url) = user;
let parsed_hash = PasswordHash::new(&hash) let parsed_hash =
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; PasswordHash::new(&hash).map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
Argon2::default() Argon2::default()
.verify_password(body.password.as_bytes(), &parsed_hash) .verify_password(body.password.as_bytes(), &parsed_hash)
@ -109,6 +111,7 @@ pub async fn login(
})) }))
} }
/// Return the caller's current public profile information.
pub async fn me( pub async fn me(
auth: AuthUser, auth: AuthUser,
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,

View File

@ -13,6 +13,7 @@ use crate::{
AppState, AppState,
}; };
/// Response payload for a newly created invite link.
#[derive(serde::Serialize)] #[derive(serde::Serialize)]
pub struct InviteResponse { pub struct InviteResponse {
pub id: String, pub id: String,
@ -20,6 +21,7 @@ pub struct InviteResponse {
pub invite_url: String, pub invite_url: String,
} }
/// Create a one-time invite token for a room member to share.
pub async fn create_invite( pub async fn create_invite(
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,
auth: AuthUser, auth: AuthUser,
@ -46,15 +48,17 @@ pub async fn create_invite(
.map(char::from) .map(char::from)
.collect(); .collect();
sqlx::query("INSERT INTO invites (id, room_id, invited_by, email, token) VALUES (?, ?, ?, ?, ?)") sqlx::query(
.bind(&invite_id) "INSERT INTO invites (id, room_id, invited_by, email, token) VALUES (?, ?, ?, ?, ?)",
.bind(&body.room_id) )
.bind(&auth.user_id) .bind(&invite_id)
.bind(&body.email) .bind(&body.room_id)
.bind(&token) .bind(&auth.user_id)
.execute(&state.db) .bind(&body.email)
.await .bind(&token)
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; .execute(&state.db)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
Ok(Json(InviteResponse { Ok(Json(InviteResponse {
id: invite_id, id: invite_id,
@ -63,11 +67,13 @@ pub async fn create_invite(
})) }))
} }
/// Response payload returned after consuming an invite.
#[derive(serde::Serialize)] #[derive(serde::Serialize)]
pub struct AcceptInviteResponse { pub struct AcceptInviteResponse {
pub room_id: String, pub room_id: String,
} }
/// Consume an invite token and add the caller to the room.
pub async fn accept_invite( pub async fn accept_invite(
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,
auth: AuthUser, auth: AuthUser,
@ -89,13 +95,12 @@ pub async fn accept_invite(
} }
// Verify room is not deleted // Verify room is not deleted
let room_active = sqlx::query_scalar::<_, String>( let room_active =
"SELECT id FROM rooms WHERE id = ? AND deleted_at IS NULL", sqlx::query_scalar::<_, String>("SELECT id FROM rooms WHERE id = ? AND deleted_at IS NULL")
) .bind(&room_id)
.bind(&room_id) .fetch_optional(&state.db)
.fetch_optional(&state.db) .await
.await .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
if room_active.is_none() { if room_active.is_none() {
return Err((StatusCode::GONE, "This room has been deleted".into())); return Err((StatusCode::GONE, "This room has been deleted".into()));
@ -119,6 +124,7 @@ pub async fn accept_invite(
Ok(Json(AcceptInviteResponse { room_id })) Ok(Json(AcceptInviteResponse { room_id }))
} }
/// Result of a Nostr-based room invite attempt.
#[derive(serde::Serialize)] #[derive(serde::Serialize)]
pub struct NostrInviteResponse { pub struct NostrInviteResponse {
pub status: String, pub status: String,
@ -126,6 +132,7 @@ pub struct NostrInviteResponse {
pub display_name: Option<String>, pub display_name: Option<String>,
} }
/// Add a user to a room by their Nostr public key if they already have an account.
pub async fn invite_by_nostr( pub async fn invite_by_nostr(
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,
auth: AuthUser, auth: AuthUser,
@ -138,8 +145,13 @@ pub async fn invite_by_nostr(
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid npub format".to_string()))? .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid npub format".to_string()))?
} else { } else {
// Validate it's valid hex // Validate it's valid hex
if body.nostr_pubkey.len() != 64 || !body.nostr_pubkey.chars().all(|c| c.is_ascii_hexdigit()) { if body.nostr_pubkey.len() != 64
return Err((StatusCode::BAD_REQUEST, "Invalid pubkey: must be 64-char hex or npub".to_string())); || !body.nostr_pubkey.chars().all(|c| c.is_ascii_hexdigit())
{
return Err((
StatusCode::BAD_REQUEST,
"Invalid pubkey: must be 64-char hex or npub".to_string(),
));
} }
body.nostr_pubkey.clone() body.nostr_pubkey.clone()
}; };

View File

@ -1,3 +1,8 @@
//! HTTP and WebSocket entry points for the server.
//!
//! Each submodule exposes route handlers that Axum wires into the router in
//! `main.rs`.
pub mod auth; pub mod auth;
pub mod invites; pub mod invites;
pub mod models; pub mod models;

View File

@ -1,19 +1,16 @@
use axum::{ use axum::{extract::State, http::StatusCode, Json};
extract::State,
http::StatusCode,
Json,
};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::OnceCell;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tokio::sync::OnceCell;
use crate::AppState; use crate::AppState;
/// Cached model list with expiry. /// Cached model list with expiry.
static MODEL_CACHE: OnceCell<Mutex<CachedModels>> = OnceCell::const_new(); static MODEL_CACHE: OnceCell<Mutex<CachedModels>> = OnceCell::const_new();
/// Process-wide cache for the OpenRouter model catalog.
struct CachedModels { struct CachedModels {
models: Vec<ModelInfo>, models: Vec<ModelInfo>,
fetched_at: Instant, fetched_at: Instant,
@ -21,6 +18,7 @@ struct CachedModels {
const CACHE_TTL: Duration = Duration::from_secs(60 * 30); // 30 minutes const CACHE_TTL: Duration = Duration::from_secs(60 * 30); // 30 minutes
/// Model metadata exposed to the client for room creation and model selection.
#[derive(Debug, Clone, Serialize)] #[derive(Debug, Clone, Serialize)]
pub struct ModelInfo { pub struct ModelInfo {
pub id: String, pub id: String,
@ -56,6 +54,10 @@ struct OpenRouterArchitecture {
input_modalities: Option<Vec<String>>, input_modalities: Option<Vec<String>>,
} }
/// Fetch the model catalog directly from OpenRouter.
///
/// The result is normalized into the smaller `ModelInfo` shape that the client
/// UI needs.
async fn fetch_models(api_key: &str) -> Result<Vec<ModelInfo>, String> { async fn fetch_models(api_key: &str) -> Result<Vec<ModelInfo>, String> {
let client = reqwest::Client::new(); let client = reqwest::Client::new();
@ -82,7 +84,8 @@ async fn fetch_models(api_key: &str) -> Result<Vec<ModelInfo>, String> {
.into_iter() .into_iter()
.map(|m| { .map(|m| {
let pricing = m.pricing.as_ref(); let pricing = m.pricing.as_ref();
let supports_vision = m.architecture let supports_vision = m
.architecture
.as_ref() .as_ref()
.and_then(|a| a.input_modalities.as_ref()) .and_then(|a| a.input_modalities.as_ref())
.map(|mods| mods.iter().any(|m| m == "image")) .map(|mods| mods.iter().any(|m| m == "image"))
@ -102,6 +105,7 @@ async fn fetch_models(api_key: &str) -> Result<Vec<ModelInfo>, String> {
Ok(models) Ok(models)
} }
/// Return the cached OpenRouter model list, refreshing it when the cache expires.
pub async fn list_models( pub async fn list_models(
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,
) -> Result<Json<Vec<ModelInfo>>, (StatusCode, String)> { ) -> Result<Json<Vec<ModelInfo>>, (StatusCode, String)> {

View File

@ -10,6 +10,7 @@ use crate::{
AppState, AppState,
}; };
/// Claims embedded in the short-lived challenge token used during Nostr login.
#[derive(Debug, serde::Serialize, serde::Deserialize)] #[derive(Debug, serde::Serialize, serde::Deserialize)]
struct ChallengeClaims { struct ChallengeClaims {
pub nonce: String, pub nonce: String,
@ -28,10 +29,7 @@ pub async fn challenge(
let exp = (chrono::Utc::now().timestamp() + 120) as usize; // 2 minutes let exp = (chrono::Utc::now().timestamp() + 120) as usize; // 2 minutes
let claims = ChallengeClaims { let claims = ChallengeClaims { nonce, exp };
nonce,
exp,
};
let token = encode( let token = encode(
&Header::default(), &Header::default(),
@ -45,6 +43,7 @@ pub async fn challenge(
/// Simple hex encoder (avoid adding the `hex` crate just for this) /// Simple hex encoder (avoid adding the `hex` crate just for this)
mod hex { mod hex {
/// Convert raw bytes into a lowercase hexadecimal string.
pub fn encode(bytes: &[u8]) -> String { pub fn encode(bytes: &[u8]) -> String {
bytes.iter().map(|b| format!("{:02x}", b)).collect() bytes.iter().map(|b| format!("{:02x}", b)).collect()
} }
@ -61,17 +60,29 @@ pub async fn verify(
&DecodingKey::from_secret(state.jwt_secret.as_bytes()), &DecodingKey::from_secret(state.jwt_secret.as_bytes()),
&Validation::default(), &Validation::default(),
) )
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid or expired challenge".to_string()))?; .map_err(|_| {
(
StatusCode::BAD_REQUEST,
"Invalid or expired challenge".to_string(),
)
})?;
let nonce = &challenge_data.claims.nonce; let nonce = &challenge_data.claims.nonce;
// 2. Deserialize signed_event as nostr::Event // 2. Deserialize signed_event as nostr::Event
let event: Event = serde_json::from_str(&body.signed_event) let event: Event = serde_json::from_str(&body.signed_event).map_err(|e| {
.map_err(|e| (StatusCode::BAD_REQUEST, format!("Invalid event JSON: {}", e)))?; (
StatusCode::BAD_REQUEST,
format!("Invalid event JSON: {}", e),
)
})?;
// 3. Verify Schnorr signature // 3. Verify Schnorr signature
if !event.verify_signature() { if !event.verify_signature() {
return Err((StatusCode::UNAUTHORIZED, "Invalid event signature".to_string())); return Err((
StatusCode::UNAUTHORIZED,
"Invalid event signature".to_string(),
));
} }
// 4. Verify event.content == nonce // 4. Verify event.content == nonce
@ -83,7 +94,10 @@ pub async fn verify(
let now = chrono::Utc::now().timestamp() as u64; let now = chrono::Utc::now().timestamp() as u64;
let event_ts = event.created_at.as_secs(); let event_ts = event.created_at.as_secs();
if now.abs_diff(event_ts) > 300 { if now.abs_diff(event_ts) > 300 {
return Err((StatusCode::BAD_REQUEST, "Event timestamp too far off".to_string())); return Err((
StatusCode::BAD_REQUEST,
"Event timestamp too far off".to_string(),
));
} }
// 6. Extract pubkey hex // 6. Extract pubkey hex

View File

@ -11,6 +11,7 @@ use crate::{
AppState, AppState,
}; };
/// Request body for profile updates.
#[derive(Debug, serde::Deserialize)] #[derive(Debug, serde::Deserialize)]
pub struct UpdateProfileRequest { pub struct UpdateProfileRequest {
pub display_name: Option<String>, pub display_name: Option<String>,
@ -25,7 +26,10 @@ pub async fn update_profile(
let display_name = body.display_name.unwrap_or(auth.display_name.clone()); let display_name = body.display_name.unwrap_or(auth.display_name.clone());
if display_name.trim().is_empty() { if display_name.trim().is_empty() {
return Err((StatusCode::BAD_REQUEST, "Display name cannot be empty".into())); return Err((
StatusCode::BAD_REQUEST,
"Display name cannot be empty".into(),
));
} }
sqlx::query("UPDATE users SET display_name = ? WHERE id = ?") sqlx::query("UPDATE users SET display_name = ? WHERE id = ?")
@ -83,7 +87,12 @@ pub async fn upload_avatar(
"image/jpeg" | "image/jpg" => "jpg", "image/jpeg" | "image/jpg" => "jpg",
"image/gif" => "gif", "image/gif" => "gif",
"image/webp" => "webp", "image/webp" => "webp",
_ => return Err((StatusCode::BAD_REQUEST, "Only PNG, JPG, GIF, and WebP images are allowed".into())), _ => {
return Err((
StatusCode::BAD_REQUEST,
"Only PNG, JPG, GIF, and WebP images are allowed".into(),
))
}
}; };
let data = field let data = field
@ -130,8 +139,13 @@ pub async fn upload_avatar(
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
// Issue new token // Issue new token
let token = create_token(&auth.user_id, &auth.email, &auth.display_name, &state.jwt_secret) let token = create_token(
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; &auth.user_id,
&auth.email,
&auth.display_name,
&state.jwt_secret,
)
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
Ok(Json(AuthResponse { Ok(Json(AuthResponse {
token, token,
@ -168,8 +182,13 @@ pub async fn delete_avatar(
.await .await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
let token = create_token(&auth.user_id, &auth.email, &auth.display_name, &state.jwt_secret) let token = create_token(
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; &auth.user_id,
&auth.email,
&auth.display_name,
&state.jwt_secret,
)
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
Ok(Json(AuthResponse { Ok(Json(AuthResponse {
token, token,

View File

@ -8,10 +8,13 @@ use uuid::Uuid;
use crate::{ use crate::{
middleware::auth::AuthUser, middleware::auth::AuthUser,
models::{self, CreateRoomRequest, MessagePayload, PaginationParams, Room, RoomResponse, UserPublic}, models::{
self, CreateRoomRequest, MessagePayload, PaginationParams, Room, RoomResponse, UserPublic,
},
AppState, AppState,
}; };
/// Create a room, persist it, and add the creator as the first member.
pub async fn create_room( pub async fn create_room(
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,
auth: AuthUser, auth: AuthUser,
@ -60,6 +63,7 @@ pub async fn create_room(
})) }))
} }
/// List all active rooms the caller belongs to, including current room members.
pub async fn list_rooms( pub async fn list_rooms(
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,
auth: AuthUser, auth: AuthUser,
@ -93,13 +97,15 @@ pub async fn list_rooms(
created_at: room.created_at, created_at: room.created_at,
members: members members: members
.into_iter() .into_iter()
.map(|(id, email, display_name, avatar_url, nostr_pubkey)| UserPublic { .map(
id, |(id, email, display_name, avatar_url, nostr_pubkey)| UserPublic {
email: models::public_email(&email), id,
display_name, email: models::public_email(&email),
avatar_url, display_name,
nostr_pubkey, avatar_url,
}) nostr_pubkey,
},
)
.collect(), .collect(),
}); });
} }
@ -107,6 +113,7 @@ pub async fn list_rooms(
Ok(Json(result)) Ok(Json(result))
} }
/// Return details for a single room after verifying the caller is a member.
pub async fn get_room( pub async fn get_room(
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,
auth: AuthUser, auth: AuthUser,
@ -152,17 +159,20 @@ pub async fn get_room(
created_at: room.created_at, created_at: room.created_at,
members: members members: members
.into_iter() .into_iter()
.map(|(id, email, display_name, avatar_url, nostr_pubkey)| UserPublic { .map(
id, |(id, email, display_name, avatar_url, nostr_pubkey)| UserPublic {
email: models::public_email(&email), id,
display_name, email: models::public_email(&email),
avatar_url, display_name,
nostr_pubkey, avatar_url,
}) nostr_pubkey,
},
)
.collect(), .collect(),
})) }))
} }
/// Return paginated message history for a room the caller can access.
pub async fn get_messages( pub async fn get_messages(
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,
auth: AuthUser, auth: AuthUser,
@ -208,37 +218,56 @@ pub async fn get_messages(
} }
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
// The SQL query reads newest-first for efficient pagination, but clients
// render chat oldest-to-newest, so reverse the rows before serializing.
let payloads: Vec<MessagePayload> = rows let payloads: Vec<MessagePayload> = rows
.into_iter() .into_iter()
.rev() .rev()
.map(|(id, room_id, sender_id, sender_name, content, mentions, is_ai, created_at, ai_meta_str, image_url, email, avatar_url, hash)| { .map(
let ai_meta = ai_meta_str |(
.as_deref()
.and_then(|s| serde_json::from_str::<crate::models::AiMeta>(s).ok());
let avatar_hash = email
.map(|e| crate::models::gravatar_hash(&e))
.unwrap_or_default();
MessagePayload {
id, id,
room_id, room_id,
sender_id, sender_id,
sender_name, sender_name,
content, content,
mentions: serde_json::from_str(&mentions).unwrap_or_default(), mentions,
is_ai, is_ai,
created_at, created_at,
ai_meta, ai_meta_str,
avatar_hash,
avatar_url,
image_url, image_url,
email,
avatar_url,
hash, hash,
} )| {
}) let ai_meta = ai_meta_str
.as_deref()
.and_then(|s| serde_json::from_str::<crate::models::AiMeta>(s).ok());
let avatar_hash = email
.map(|e| crate::models::gravatar_hash(&e))
.unwrap_or_default();
MessagePayload {
id,
room_id,
sender_id,
sender_name,
content,
mentions: serde_json::from_str(&mentions).unwrap_or_default(),
is_ai,
created_at,
ai_meta,
avatar_hash,
avatar_url,
image_url,
hash,
}
},
)
.collect(); .collect();
Ok(Json(payloads)) Ok(Json(payloads))
} }
/// Resolve a stable message hash into the room that contains it.
pub async fn resolve_message_hash( pub async fn resolve_message_hash(
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,
auth: AuthUser, auth: AuthUser,
@ -258,22 +287,29 @@ pub async fn resolve_message_hash(
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
match row { match row {
Some((room_id,)) => Ok(Json(serde_json::json!({ "room_id": room_id, "hash": hash }))), Some((room_id,)) => Ok(Json(
None => Err((StatusCode::NOT_FOUND, "Message not found or no access".into())), serde_json::json!({ "room_id": room_id, "hash": hash }),
)),
None => Err((
StatusCode::NOT_FOUND,
"Message not found or no access".into(),
)),
} }
} }
/// Add the caller to a room directly when they already know its ID.
pub async fn join_room( pub async fn join_room(
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,
auth: AuthUser, auth: AuthUser,
Path(room_id): Path<String>, Path(room_id): Path<String>,
) -> Result<StatusCode, (StatusCode, String)> { ) -> Result<StatusCode, (StatusCode, String)> {
// Check room exists // Check room exists
let room_exists = sqlx::query_scalar::<_, String>("SELECT id FROM rooms WHERE id = ? AND deleted_at IS NULL") let room_exists =
.bind(&room_id) sqlx::query_scalar::<_, String>("SELECT id FROM rooms WHERE id = ? AND deleted_at IS NULL")
.fetch_optional(&state.db) .bind(&room_id)
.await .fetch_optional(&state.db)
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; .await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
if room_exists.is_none() { if room_exists.is_none() {
return Err((StatusCode::NOT_FOUND, "Room not found".into())); return Err((StatusCode::NOT_FOUND, "Room not found".into()));
@ -289,6 +325,7 @@ pub async fn join_room(
Ok(StatusCode::OK) Ok(StatusCode::OK)
} }
/// Soft-delete a room and broadcast the deletion event to connected members.
pub async fn delete_room( pub async fn delete_room(
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,
auth: AuthUser, auth: AuthUser,
@ -303,7 +340,10 @@ pub async fn delete_room(
.ok_or((StatusCode::NOT_FOUND, "Room not found".into()))?; .ok_or((StatusCode::NOT_FOUND, "Room not found".into()))?;
if room.created_by != auth.user_id { if room.created_by != auth.user_id {
return Err((StatusCode::FORBIDDEN, "Only the room creator can delete this room".into())); return Err((
StatusCode::FORBIDDEN,
"Only the room creator can delete this room".into(),
));
} }
// Soft-delete // Soft-delete
@ -324,6 +364,7 @@ pub async fn delete_room(
Ok(StatusCode::OK) Ok(StatusCode::OK)
} }
/// Permanently remove all messages from a room without deleting the room itself.
pub async fn clear_room( pub async fn clear_room(
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,
auth: AuthUser, auth: AuthUser,
@ -338,7 +379,10 @@ pub async fn clear_room(
.ok_or((StatusCode::NOT_FOUND, "Room not found".into()))?; .ok_or((StatusCode::NOT_FOUND, "Room not found".into()))?;
if room.created_by != auth.user_id { if room.created_by != auth.user_id {
return Err((StatusCode::FORBIDDEN, "Only the room creator can clear messages".into())); return Err((
StatusCode::FORBIDDEN,
"Only the room creator can clear messages".into(),
));
} }
// Hard-delete all messages // Hard-delete all messages

View File

@ -1,18 +1,16 @@
use axum::{ use axum::{extract::Multipart, http::StatusCode, Json};
extract::Multipart,
http::StatusCode,
Json,
};
use serde::Serialize; use serde::Serialize;
use uuid::Uuid; use uuid::Uuid;
use crate::middleware::auth::AuthUser; use crate::middleware::auth::AuthUser;
/// Response returned after a chat image upload succeeds.
#[derive(Serialize)] #[derive(Serialize)]
pub struct UploadResponse { pub struct UploadResponse {
pub url: String, pub url: String,
} }
/// Accept a multipart chat image upload and store it under `uploads/chat-images`.
pub async fn upload_chat_image( pub async fn upload_chat_image(
_auth: AuthUser, _auth: AuthUser,
mut multipart: Multipart, mut multipart: Multipart,

View File

@ -1,3 +1,9 @@
//! WebSocket workflow for live chat delivery and AI responses.
//!
//! This module does two jobs:
//! - fan out database-backed room events to subscribed browser sockets
//! - turn incoming user chat messages into stored messages and optional AI replies
use axum::{ use axum::{
extract::{ extract::{
ws::{Message, WebSocket}, ws::{Message, WebSocket},
@ -24,6 +30,7 @@ pub struct WsQuery {
token: String, token: String,
} }
/// Upgrade an authenticated request into a WebSocket connection.
pub async fn ws_handler( pub async fn ws_handler(
ws: WebSocketUpgrade, ws: WebSocketUpgrade,
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,
@ -37,10 +44,19 @@ pub async fn ws_handler(
} }
}; };
ws.on_upgrade(move |socket| handle_socket(socket, state, claims.sub, claims.display_name, claims.email)) ws.on_upgrade(move |socket| {
handle_socket(socket, state, claims.sub, claims.display_name, claims.email)
})
} }
async fn handle_socket(socket: WebSocket, state: Arc<AppState>, user_id: String, display_name: String, email: String) { /// Drive a single WebSocket connection until either the send or receive side ends.
async fn handle_socket(
socket: WebSocket,
state: Arc<AppState>,
user_id: String,
display_name: String,
email: String,
) {
let (mut ws_tx, mut ws_rx) = socket.split(); let (mut ws_tx, mut ws_rx) = socket.split();
let mut broadcast_rx = state.tx.subscribe(); let mut broadcast_rx = state.tx.subscribe();
@ -50,7 +66,8 @@ async fn handle_socket(socket: WebSocket, state: Arc<AppState>, user_id: String,
let rooms_clone = subscribed_rooms.clone(); let rooms_clone = subscribed_rooms.clone();
// Task: forward broadcast events to this client // Task 1: forward room events from the shared broadcast channel into this
// specific socket, but only for rooms the browser subscribed to.
let mut send_task = tokio::spawn(async move { let mut send_task = tokio::spawn(async move {
loop { loop {
match broadcast_rx.recv().await { match broadcast_rx.recv().await {
@ -81,7 +98,8 @@ async fn handle_socket(socket: WebSocket, state: Arc<AppState>, user_id: String,
let email_clone = email.clone(); let email_clone = email.clone();
let rooms_clone2 = subscribed_rooms.clone(); let rooms_clone2 = subscribed_rooms.clone();
// Task: receive messages from client // Task 2: receive commands from the browser and translate them into
// database writes, broadcasts, or AI work.
let mut recv_task = tokio::spawn(async move { let mut recv_task = tokio::spawn(async move {
while let Some(Ok(msg)) = ws_rx.next().await { while let Some(Ok(msg)) = ws_rx.next().await {
let text = match msg { let text = match msg {
@ -141,7 +159,7 @@ async fn handle_socket(socket: WebSocket, state: Arc<AppState>, user_id: String,
} }
}); });
// Wait for either task to finish, then abort the other // If either half of the connection ends, stop the companion task too.
tokio::select! { tokio::select! {
_ = &mut send_task => recv_task.abort(), _ = &mut send_task => recv_task.abort(),
_ = &mut recv_task => send_task.abort(), _ = &mut recv_task => send_task.abort(),
@ -150,6 +168,7 @@ async fn handle_socket(socket: WebSocket, state: Arc<AppState>, user_id: String,
tracing::info!("WebSocket disconnected: {}", user_id); tracing::info!("WebSocket disconnected: {}", user_id);
} }
/// Persist a user message, broadcast it, and optionally generate an AI reply.
async fn handle_send_message( async fn handle_send_message(
state: &Arc<AppState>, state: &Arc<AppState>,
user_id: &str, user_id: &str,
@ -184,13 +203,14 @@ async fn handle_send_message(
.await; .await;
// Look up the sender's custom avatar (if any) for the message payload // Look up the sender's custom avatar (if any) for the message payload
let avatar_url: Option<String> = sqlx::query_scalar("SELECT avatar_url FROM users WHERE id = ?") let avatar_url: Option<String> =
.bind(user_id) sqlx::query_scalar("SELECT avatar_url FROM users WHERE id = ?")
.fetch_optional(&state.db) .bind(user_id)
.await .fetch_optional(&state.db)
.ok() .await
.flatten() .ok()
.flatten(); .flatten()
.flatten();
// Broadcast human message // Broadcast human message
let payload = MessagePayload { let payload = MessagePayload {
@ -211,12 +231,11 @@ async fn handle_send_message(
let _ = state.tx.send(BroadcastEvent { let _ = state.tx.send(BroadcastEvent {
room_id: room_id.to_string(), room_id: room_id.to_string(),
message: WsServerMessage::NewMessage { message: WsServerMessage::NewMessage { message: payload },
message: payload,
},
}); });
// Check if AI should respond // The AI only replies when explicitly mentioned or when the room is set to
// auto-reply to every message.
let ai_user_id = "ai-assistant"; let ai_user_id = "ai-assistant";
let should_respond = mentions.contains(&ai_user_id.to_string()); let should_respond = mentions.contains(&ai_user_id.to_string());
@ -254,7 +273,8 @@ async fn handle_send_message(
.await .await
.unwrap_or_default(); .unwrap_or_default();
// Process history: encode images as base64 data URLs for OpenRouter // OpenRouter accepts image inputs as data URLs, so local uploads need to be
// loaded from disk and encoded before they are sent upstream.
let mut history: Vec<(String, String, bool, Option<String>)> = Vec::new(); let mut history: Vec<(String, String, bool, Option<String>)> = Vec::new();
for (sender_name, msg_content, is_ai, msg_image_url) in recent_messages.into_iter().rev() { for (sender_name, msg_content, is_ai, msg_image_url) in recent_messages.into_iter().rev() {
let image_data_url = match &msg_image_url { let image_data_url = match &msg_image_url {
@ -272,7 +292,8 @@ async fn handle_send_message(
// Pre-generate AI message ID so we can reference it in stream chunks // Pre-generate AI message ID so we can reference it in stream chunks
let ai_msg_id = Uuid::new_v4().to_string(); let ai_msg_id = Uuid::new_v4().to_string();
// Call OpenRouter with tool loop — uses streaming for all rounds // Run the AI in a loop because the model may first request tools, then need
// follow-up rounds after those tool results are added to history.
let mut total_prompt_tokens: u32 = 0; let mut total_prompt_tokens: u32 = 0;
let mut total_completion_tokens: u32 = 0; let mut total_completion_tokens: u32 = 0;
let mut total_response_ms: u64 = 0; let mut total_response_ms: u64 = 0;
@ -313,16 +334,24 @@ async fn handle_send_message(
tracing::info!( tracing::info!(
"AI requesting tool calls (round {}): {:?}", "AI requesting tool calls (round {}): {:?}",
round + 1, round + 1,
assistant_msg.tool_calls.as_ref().map(|tc| tc.iter().map(|t| &t.function.name).collect::<Vec<_>>()) assistant_msg
.tool_calls
.as_ref()
.map(|tc| tc.iter().map(|t| &t.function.name).collect::<Vec<_>>())
); );
// Add the assistant's tool-call message to history // Preserve the assistant tool-call message so the next round
// has the same context the model produced.
let tool_calls = assistant_msg.tool_calls.clone().unwrap_or_default(); let tool_calls = assistant_msg.tool_calls.clone().unwrap_or_default();
chat_history.push(assistant_msg); chat_history.push(assistant_msg);
// Execute each tool call and add results // Tool results are fed back into the conversation as
// synthetic `tool` messages, matching the upstream API.
for tool_call in &tool_calls { for tool_call in &tool_calls {
let tool_input = extract_tool_input(&tool_call.function.name, &tool_call.function.arguments); let tool_input = extract_tool_input(
&tool_call.function.name,
&tool_call.function.arguments,
);
// Broadcast real-time tool usage event // Broadcast real-time tool usage event
let _ = state.tx.send(BroadcastEvent { let _ = state.tx.send(BroadcastEvent {
@ -362,7 +391,7 @@ async fn handle_send_message(
tool_call_id: Some(tool_call.id.clone()), tool_call_id: Some(tool_call.id.clone()),
}); });
} }
// Continue to next round (tool loop) // Ask the model to continue now that tool output exists.
continue 'tool_loop; continue 'tool_loop;
} }
openrouter::StreamEvent::Done(stats) => { openrouter::StreamEvent::Done(stats) => {
@ -382,9 +411,12 @@ async fn handle_send_message(
} }
} }
// If we exhausted all rounds without a text response, note it // Guardrail: if the model never produced final prose, store a clear fallback
// instead of leaving the client waiting indefinitely.
if ai_response.is_empty() && !had_error { if ai_response.is_empty() && !had_error {
ai_response = "*I used several tools but couldn't formulate a final response. Please try again.*".to_string(); ai_response =
"*I used several tools but couldn't formulate a final response. Please try again.*"
.to_string();
} }
// Signal stream end so client can finalize rendering // Signal stream end so client can finalize rendering
@ -512,7 +544,15 @@ async fn execute_tool(
return "Error: search query is required".into(); return "Error: search query is required".into();
} }
match search::search(search_provider, &query, tavily_api_key, brave_api_key, count).await { match search::search(
search_provider,
&query,
tavily_api_key,
brave_api_key,
count,
)
.await
{
Ok(results) => search::format_results(&results), Ok(results) => search::format_results(&results),
Err(e) => format!("Search error: {}", e), Err(e) => format!("Search error: {}", e),
} }

View File

@ -1,3 +1,12 @@
//! Application bootstrap for the GroupChat server.
//!
//! This file is responsible for:
//! - loading environment configuration
//! - opening and migrating the SQLite database
//! - constructing shared application state
//! - registering HTTP/WebSocket routes
//! - serving the SPA frontend in production
mod handlers; mod handlers;
mod middleware; mod middleware;
mod models; mod models;
@ -51,14 +60,8 @@ fn backup_database(database_url: &str) {
} }
// Build timestamped backup filename: chat.db -> chat_2026-03-09_143022.db // Build timestamped backup filename: chat.db -> chat_2026-03-09_143022.db
let stem = db_file let stem = db_file.file_stem().and_then(|s| s.to_str()).unwrap_or("db");
.file_stem() let ext = db_file.extension().and_then(|s| s.to_str()).unwrap_or("db");
.and_then(|s| s.to_str())
.unwrap_or("db");
let ext = db_file
.extension()
.and_then(|s| s.to_str())
.unwrap_or("db");
let now = chrono::Local::now(); let now = chrono::Local::now();
let backup_name = format!("{}_{}.{}", stem, now.format("%Y-%m-%d_%H%M%S"), ext); let backup_name = format!("{}_{}.{}", stem, now.format("%Y-%m-%d_%H%M%S"), ext);
@ -82,11 +85,21 @@ fn backup_database(database_url: &str) {
let wal_path = format!("{}-wal", db_path); let wal_path = format!("{}-wal", db_path);
let shm_path = format!("{}-shm", db_path); let shm_path = format!("{}-shm", db_path);
if std::path::Path::new(&wal_path).exists() { if std::path::Path::new(&wal_path).exists() {
let wal_backup = backup_dir.join(format!("{}_{}.{}-wal", stem, now.format("%Y-%m-%d_%H%M%S"), ext)); let wal_backup = backup_dir.join(format!(
"{}_{}.{}-wal",
stem,
now.format("%Y-%m-%d_%H%M%S"),
ext
));
let _ = std::fs::copy(&wal_path, &wal_backup); let _ = std::fs::copy(&wal_path, &wal_backup);
} }
if std::path::Path::new(&shm_path).exists() { if std::path::Path::new(&shm_path).exists() {
let shm_backup = backup_dir.join(format!("{}_{}.{}-shm", stem, now.format("%Y-%m-%d_%H%M%S"), ext)); let shm_backup = backup_dir.join(format!(
"{}_{}.{}-shm",
stem,
now.format("%Y-%m-%d_%H%M%S"),
ext
));
let _ = std::fs::copy(&shm_path, &shm_backup); let _ = std::fs::copy(&shm_path, &shm_backup);
} }
@ -119,20 +132,35 @@ fn prune_old_backups(backup_dir: &std::path::Path, stem: &str, keep: usize) {
let to_remove = backups.len() - keep; let to_remove = backups.len() - keep;
for entry in backups.into_iter().take(to_remove) { for entry in backups.into_iter().take(to_remove) {
let path = entry.path(); let path = entry.path();
let name = path.file_name().unwrap_or_default().to_string_lossy().to_string(); let name = path
.file_name()
.unwrap_or_default()
.to_string_lossy()
.to_string();
if let Err(e) = std::fs::remove_file(&path) { if let Err(e) = std::fs::remove_file(&path) {
tracing::warn!("Failed to remove old backup {}: {}", name, e); tracing::warn!("Failed to remove old backup {}: {}", name, e);
} else { } else {
tracing::debug!("Pruned old backup: {}", name); tracing::debug!("Pruned old backup: {}", name);
// Also remove associated WAL/SHM backups // Also remove associated WAL/SHM backups
let wal = path.with_extension(format!("{}-wal", path.extension().unwrap_or_default().to_string_lossy())); let wal = path.with_extension(format!(
let shm = path.with_extension(format!("{}-shm", path.extension().unwrap_or_default().to_string_lossy())); "{}-wal",
path.extension().unwrap_or_default().to_string_lossy()
));
let shm = path.with_extension(format!(
"{}-shm",
path.extension().unwrap_or_default().to_string_lossy()
));
let _ = std::fs::remove_file(&wal); let _ = std::fs::remove_file(&wal);
let _ = std::fs::remove_file(&shm); let _ = std::fs::remove_file(&shm);
} }
} }
} }
/// Shared state injected into every handler.
///
/// Axum stores this behind an `Arc`, so handlers can cheaply clone the pointer
/// while all requests still talk to the same database pool, API keys, and
/// broadcast channel.
pub struct AppState { pub struct AppState {
pub db: sqlx::SqlitePool, pub db: sqlx::SqlitePool,
pub jwt_secret: String, pub jwt_secret: String,
@ -154,11 +182,15 @@ async fn main() {
.with(tracing_subscriber::fmt::layer()) .with(tracing_subscriber::fmt::layer())
.init(); .init();
let database_url = std::env::var("DATABASE_URL").unwrap_or_else(|_| "sqlite:chat.db?mode=rwc".into()); // Load the runtime configuration needed to start the server.
let database_url =
std::env::var("DATABASE_URL").unwrap_or_else(|_| "sqlite:chat.db?mode=rwc".into());
let jwt_secret = std::env::var("JWT_SECRET").unwrap_or_else(|_| "dev-secret-change-me".into()); let jwt_secret = std::env::var("JWT_SECRET").unwrap_or_else(|_| "dev-secret-change-me".into());
let openrouter_key = std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY must be set"); let openrouter_key =
let search_provider = SearchProvider::from_env(std::env::var("SEARCH_PROVIDER").ok().as_deref()) std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY must be set");
.unwrap_or_else(|e| panic!("{}", e)); let search_provider =
SearchProvider::from_env(std::env::var("SEARCH_PROVIDER").ok().as_deref())
.unwrap_or_else(|e| panic!("{}", e));
let tavily_api_key = std::env::var("TAVILY_API_KEY").ok(); let tavily_api_key = std::env::var("TAVILY_API_KEY").ok();
let brave_api_key = std::env::var("BRAVE_API_KEY").ok(); let brave_api_key = std::env::var("BRAVE_API_KEY").ok();
@ -181,7 +213,8 @@ async fn main() {
.await .await
.expect("Failed to connect to database"); .expect("Failed to connect to database");
// Run migrations // Run migrations in order. Each one is written so startup can safely try it
// again and skip work that already happened in an earlier run.
let migration_sql = include_str!("../migrations/001_init.sql"); let migration_sql = include_str!("../migrations/001_init.sql");
sqlx::raw_sql(migration_sql) sqlx::raw_sql(migration_sql)
.execute(&db) .execute(&db)
@ -282,6 +315,8 @@ async fn main() {
tracing::info!("Database initialized"); tracing::info!("Database initialized");
// WebSocket tasks subscribe to this channel to receive room events without
// polling the database.
let (tx, _rx) = broadcast::channel::<models::BroadcastEvent>(4096); let (tx, _rx) = broadcast::channel::<models::BroadcastEvent>(4096);
let state = Arc::new(AppState { let state = Arc::new(AppState {
@ -302,32 +337,61 @@ async fn main() {
// Serve static files from client dist in production // Serve static files from client dist in production
let static_dir = std::env::var("STATIC_DIR").unwrap_or_else(|_| "../client/dist".into()); let static_dir = std::env::var("STATIC_DIR").unwrap_or_else(|_| "../client/dist".into());
// Keep API routes separate from the static-file fallback so `/api/*` and
// `/ws` requests never get mistaken for SPA routes.
let api_routes = Router::new() let api_routes = Router::new()
// Auth routes // Auth routes
.route("/api/auth/register", post(handlers::auth::register)) .route("/api/auth/register", post(handlers::auth::register))
.route("/api/auth/login", post(handlers::auth::login)) .route("/api/auth/login", post(handlers::auth::login))
.route("/api/auth/me", get(handlers::auth::me)) .route("/api/auth/me", get(handlers::auth::me))
// Nostr auth routes // Nostr auth routes
.route("/api/auth/nostr/challenge", get(handlers::nostr_auth::challenge)) .route(
"/api/auth/nostr/challenge",
get(handlers::nostr_auth::challenge),
)
.route("/api/auth/nostr/verify", post(handlers::nostr_auth::verify)) .route("/api/auth/nostr/verify", post(handlers::nostr_auth::verify))
// Profile routes // Profile routes
.route("/api/auth/profile", put(handlers::profile::update_profile)) .route("/api/auth/profile", put(handlers::profile::update_profile))
.route("/api/auth/avatar", post(handlers::profile::upload_avatar).delete(handlers::profile::delete_avatar)) .route(
"/api/auth/avatar",
post(handlers::profile::upload_avatar).delete(handlers::profile::delete_avatar),
)
// Room routes // Room routes
.route("/api/rooms", get(handlers::rooms::list_rooms).post(handlers::rooms::create_room)) .route(
.route("/api/rooms/:room_id", get(handlers::rooms::get_room).delete(handlers::rooms::delete_room)) "/api/rooms",
.route("/api/rooms/:room_id/messages", get(handlers::rooms::get_messages)) get(handlers::rooms::list_rooms).post(handlers::rooms::create_room),
)
.route(
"/api/rooms/:room_id",
get(handlers::rooms::get_room).delete(handlers::rooms::delete_room),
)
.route(
"/api/rooms/:room_id/messages",
get(handlers::rooms::get_messages),
)
.route("/api/rooms/:room_id/join", post(handlers::rooms::join_room)) .route("/api/rooms/:room_id/join", post(handlers::rooms::join_room))
.route("/api/rooms/:room_id/clear", post(handlers::rooms::clear_room)) .route(
.route("/api/messages/hash/:hash", get(handlers::rooms::resolve_message_hash)) "/api/rooms/:room_id/clear",
post(handlers::rooms::clear_room),
)
.route(
"/api/messages/hash/:hash",
get(handlers::rooms::resolve_message_hash),
)
// Upload (chat images) // Upload (chat images)
.route("/api/upload", post(handlers::upload::upload_chat_image)) .route("/api/upload", post(handlers::upload::upload_chat_image))
// Models // Models
.route("/api/models", get(handlers::models::list_models)) .route("/api/models", get(handlers::models::list_models))
// Invite routes // Invite routes
.route("/api/invites", post(handlers::invites::create_invite)) .route("/api/invites", post(handlers::invites::create_invite))
.route("/api/invites/:token/accept", post(handlers::invites::accept_invite)) .route(
.route("/api/invites/nostr", post(handlers::invites::invite_by_nostr)) "/api/invites/:token/accept",
post(handlers::invites::accept_invite),
)
.route(
"/api/invites/nostr",
post(handlers::invites::invite_by_nostr),
)
// Uploaded files (avatars) // Uploaded files (avatars)
.nest_service("/uploads", ServeDir::new("uploads")) .nest_service("/uploads", ServeDir::new("uploads"))
// WebSocket // WebSocket

View File

@ -1,14 +1,11 @@
use async_trait::async_trait; use async_trait::async_trait;
use axum::{ use axum::{extract::FromRequestParts, http::request::Parts};
extract::FromRequestParts,
http::request::Parts,
};
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation}; use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
use std::sync::Arc; use std::sync::Arc;
use crate::{models::Claims, AppState}; use crate::{models::Claims, AppState};
/// Extract authenticated user from JWT in Authorization header /// Authenticated user information extracted from the bearer token.
pub struct AuthUser { pub struct AuthUser {
pub user_id: String, pub user_id: String,
pub email: String, pub email: String,
@ -19,7 +16,15 @@ pub struct AuthUser {
impl FromRequestParts<Arc<AppState>> for AuthUser { impl FromRequestParts<Arc<AppState>> for AuthUser {
type Rejection = axum::http::StatusCode; type Rejection = axum::http::StatusCode;
async fn from_request_parts(parts: &mut Parts, state: &Arc<AppState>) -> Result<Self, Self::Rejection> { /// Read the `Authorization: Bearer <token>` header and decode the JWT.
///
/// Axum runs this automatically for any handler parameter of type
/// `AuthUser`, which keeps individual handlers free from repeated token
/// parsing logic.
async fn from_request_parts(
parts: &mut Parts,
state: &Arc<AppState>,
) -> Result<Self, Self::Rejection> {
let auth_header = parts let auth_header = parts
.headers .headers
.get("Authorization") .get("Authorization")
@ -41,7 +46,16 @@ impl FromRequestParts<Arc<AppState>> for AuthUser {
} }
} }
pub fn create_token(user_id: &str, email: &str, display_name: &str, secret: &str) -> Result<String, jsonwebtoken::errors::Error> { /// Create a signed JWT for a logged-in user.
///
/// The token expires after seven days and carries the small amount of identity
/// data the server wants available on every request.
pub fn create_token(
user_id: &str,
email: &str,
display_name: &str,
secret: &str,
) -> Result<String, jsonwebtoken::errors::Error> {
let expiration = chrono::Utc::now() let expiration = chrono::Utc::now()
.checked_add_signed(chrono::Duration::days(7)) .checked_add_signed(chrono::Duration::days(7))
.unwrap() .unwrap()
@ -61,6 +75,7 @@ pub fn create_token(user_id: &str, email: &str, display_name: &str, secret: &str
) )
} }
/// Decode and validate a previously issued JWT.
pub fn decode_token(token: &str, secret: &str) -> Result<Claims, jsonwebtoken::errors::Error> { pub fn decode_token(token: &str, secret: &str) -> Result<Claims, jsonwebtoken::errors::Error> {
let token_data = decode::<Claims>( let token_data = decode::<Claims>(
token, token,

View File

@ -1 +1,3 @@
//! Reusable request-processing layers shared across handlers.
pub mod auth; pub mod auth;

View File

@ -1,7 +1,14 @@
//! Core data structures shared across the server.
//!
//! This file intentionally mixes database row types, HTTP payloads, WebSocket
//! payloads, and a few helper functions so the rest of the codebase can import
//! common shapes from one place.
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
// ── Database models ── // ── Database models ──
/// Row from the `users` table.
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] #[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct User { pub struct User {
pub id: String, pub id: String,
@ -11,6 +18,7 @@ pub struct User {
pub created_at: String, pub created_at: String,
} }
/// Row from the `rooms` table.
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] #[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct Room { pub struct Room {
pub id: String, pub id: String,
@ -24,6 +32,7 @@ pub struct Room {
pub deleted_at: Option<String>, pub deleted_at: Option<String>,
} }
/// Row from the `messages` table.
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] #[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct Message { pub struct Message {
pub id: String, pub id: String,
@ -38,6 +47,7 @@ pub struct Message {
pub hash: Option<String>, pub hash: Option<String>,
} }
/// Row from the `invites` table.
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] #[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct Invite { pub struct Invite {
pub id: String, pub id: String,
@ -51,6 +61,7 @@ pub struct Invite {
// ── API request/response types ── // ── API request/response types ──
/// JSON body expected by the registration endpoint.
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
pub struct RegisterRequest { pub struct RegisterRequest {
pub email: String, pub email: String,
@ -58,18 +69,21 @@ pub struct RegisterRequest {
pub display_name: String, pub display_name: String,
} }
/// JSON body expected by the login endpoint.
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
pub struct LoginRequest { pub struct LoginRequest {
pub email: String, pub email: String,
pub password: String, pub password: String,
} }
/// Standard auth response returned after login, registration, or profile update.
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
pub struct AuthResponse { pub struct AuthResponse {
pub token: String, pub token: String,
pub user: UserPublic, pub user: UserPublic,
} }
/// Public user data safe to return to any authenticated client.
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserPublic { pub struct UserPublic {
pub id: String, pub id: String,
@ -81,6 +95,7 @@ pub struct UserPublic {
pub nostr_pubkey: Option<String>, pub nostr_pubkey: Option<String>,
} }
/// JSON body used when a user creates a new chat room.
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
pub struct CreateRoomRequest { pub struct CreateRoomRequest {
pub name: String, pub name: String,
@ -93,10 +108,10 @@ pub struct CreateRoomRequest {
pub ai_name: String, pub ai_name: String,
} }
/// Pick a friendly default AI display name when the creator does not specify one.
fn default_ai_name() -> String { fn default_ai_name() -> String {
let names = [ let names = [
"Nova", "Atlas", "Sage", "Echo", "Pixel", "Nova", "Atlas", "Sage", "Echo", "Pixel", "Cosmo", "Ember", "Flux", "Lyra", "Onyx",
"Cosmo", "Ember", "Flux", "Lyra", "Onyx",
]; ];
let idx = std::time::SystemTime::now() let idx = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH) .duration_since(std::time::UNIX_EPOCH)
@ -105,10 +120,12 @@ fn default_ai_name() -> String {
names[idx].to_string() names[idx].to_string()
} }
/// Default prompt that defines the AI assistant's behavior inside a room.
fn default_system_prompt() -> String { fn default_system_prompt() -> String {
"You are a helpful AI assistant participating in a group chat. Be conversational, helpful, and concise. You can see messages from all participants. When mentioned with @ai, respond helpfully.\n\nYou have access to tools:\n- **web_search**: Search the web for current information. Use this when asked about recent events, news, facts you're unsure about, or anything that needs up-to-date information.\n- **web_fetch**: Fetch and read the content of a web page. Use this when a user shares a URL and wants you to read/summarize it, or when you need more details from a search result.\n\nUse tools proactively when they would help answer the question better. You don't need to ask permission to use them.".to_string() "You are a helpful AI assistant participating in a group chat. Be conversational, helpful, and concise. You can see messages from all participants. When mentioned with @ai, respond helpfully.\n\nYou have access to tools:\n- **web_search**: Search the web for current information. Use this when asked about recent events, news, facts you're unsure about, or anything that needs up-to-date information.\n- **web_fetch**: Fetch and read the content of a web page. Use this when a user shares a URL and wants you to read/summarize it, or when you need more details from a search result.\n\nUse tools proactively when they would help answer the question better. You don't need to ask permission to use them.".to_string()
} }
/// Full room payload returned to the client, including current members.
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
pub struct RoomResponse { pub struct RoomResponse {
pub id: String, pub id: String,
@ -122,6 +139,7 @@ pub struct RoomResponse {
pub members: Vec<UserPublic>, pub members: Vec<UserPublic>,
} }
/// JSON body for an email-based room invite.
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
pub struct CreateInviteRequest { pub struct CreateInviteRequest {
pub room_id: String, pub room_id: String,
@ -130,6 +148,7 @@ pub struct CreateInviteRequest {
// ── WebSocket event types ── // ── WebSocket event types ──
/// Messages the browser can send over the WebSocket connection.
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")] #[serde(tag = "type")]
pub enum WsClientMessage { pub enum WsClientMessage {
@ -148,17 +167,14 @@ pub enum WsClientMessage {
Typing { room_id: String }, Typing { room_id: String },
} }
/// Messages the server can push to browsers over the WebSocket connection.
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")] #[serde(tag = "type")]
pub enum WsServerMessage { pub enum WsServerMessage {
#[serde(rename = "new_message")] #[serde(rename = "new_message")]
NewMessage { NewMessage { message: MessagePayload },
message: MessagePayload,
},
#[serde(rename = "ai_typing")] #[serde(rename = "ai_typing")]
AiTyping { AiTyping { room_id: String },
room_id: String,
},
#[serde(rename = "user_typing")] #[serde(rename = "user_typing")]
UserTyping { UserTyping {
room_id: String, room_id: String,
@ -166,21 +182,13 @@ pub enum WsServerMessage {
display_name: String, display_name: String,
}, },
#[serde(rename = "error")] #[serde(rename = "error")]
Error { Error { message: String },
message: String,
},
#[serde(rename = "joined")] #[serde(rename = "joined")]
Joined { Joined { room_id: String },
room_id: String,
},
#[serde(rename = "room_deleted")] #[serde(rename = "room_deleted")]
RoomDeleted { RoomDeleted { room_id: String },
room_id: String,
},
#[serde(rename = "room_cleared")] #[serde(rename = "room_cleared")]
RoomCleared { RoomCleared { room_id: String },
room_id: String,
},
#[serde(rename = "ai_tool_usage")] #[serde(rename = "ai_tool_usage")]
AiToolUsage { AiToolUsage {
room_id: String, room_id: String,
@ -194,12 +202,10 @@ pub enum WsServerMessage {
delta: String, delta: String,
}, },
#[serde(rename = "ai_stream_end")] #[serde(rename = "ai_stream_end")]
AiStreamEnd { AiStreamEnd { room_id: String, message_id: String },
room_id: String,
message_id: String,
},
} }
/// Message shape sent to clients for history loading and live updates.
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessagePayload { pub struct MessagePayload {
pub id: String, pub id: String,
@ -224,7 +230,7 @@ pub struct MessagePayload {
/// Compute Gravatar-compatible MD5 hash from an email address. /// Compute Gravatar-compatible MD5 hash from an email address.
pub fn gravatar_hash(email: &str) -> String { pub fn gravatar_hash(email: &str) -> String {
use md5::{Md5, Digest}; use md5::{Digest, Md5};
let normalized = email.trim().to_lowercase(); let normalized = email.trim().to_lowercase();
let result = Md5::digest(normalized.as_bytes()); let result = Md5::digest(normalized.as_bytes());
format!("{:x}", result) format!("{:x}", result)
@ -232,13 +238,14 @@ pub fn gravatar_hash(email: &str) -> String {
/// Compute SHA-256 integrity hash from created_at timestamp + message content. /// Compute SHA-256 integrity hash from created_at timestamp + message content.
pub fn message_hash(created_at: &str, content: &str) -> String { pub fn message_hash(created_at: &str, content: &str) -> String {
use sha2::{Sha256, Digest}; use sha2::{Digest, Sha256};
let mut hasher = Sha256::new(); let mut hasher = Sha256::new();
hasher.update(created_at.as_bytes()); hasher.update(created_at.as_bytes());
hasher.update(content.as_bytes()); hasher.update(content.as_bytes());
format!("{:x}", hasher.finalize()) format!("{:x}", hasher.finalize())
} }
/// Usage and tool metadata captured for AI-generated messages.
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AiMeta { pub struct AiMeta {
pub model: String, pub model: String,
@ -250,6 +257,7 @@ pub struct AiMeta {
pub tool_results: Option<Vec<ToolResult>>, pub tool_results: Option<Vec<ToolResult>>,
} }
/// One tool invocation performed while generating an AI answer.
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResult { pub struct ToolResult {
pub tool: String, pub tool: String,
@ -259,6 +267,7 @@ pub struct ToolResult {
// ── Broadcast event (internal channel) ── // ── Broadcast event (internal channel) ──
/// Internal event sent through a Tokio broadcast channel to WebSocket tasks.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct BroadcastEvent { pub struct BroadcastEvent {
pub room_id: String, pub room_id: String,
@ -267,9 +276,10 @@ pub struct BroadcastEvent {
// ── JWT Claims ── // ── JWT Claims ──
/// Claims stored inside the server-issued JWT.
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct Claims { pub struct Claims {
pub sub: String, // user_id pub sub: String, // user_id
pub email: String, pub email: String,
pub display_name: String, pub display_name: String,
pub exp: usize, pub exp: usize,
@ -277,6 +287,7 @@ pub struct Claims {
// ── Pagination ── // ── Pagination ──
/// Common pagination parameters for message history endpoints.
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
pub struct PaginationParams { pub struct PaginationParams {
#[serde(default = "default_limit")] #[serde(default = "default_limit")]
@ -288,7 +299,7 @@ fn default_limit() -> i64 {
50 50
} }
/// Returns "" if the email is a sentinel nostr: value, otherwise returns it as-is. /// Hide placeholder `nostr:*` emails from normal client responses.
pub fn public_email(email: &str) -> String { pub fn public_email(email: &str) -> String {
if email.starts_with("nostr:") { if email.starts_with("nostr:") {
String::new() String::new()
@ -299,11 +310,13 @@ pub fn public_email(email: &str) -> String {
// ── Nostr auth types ── // ── Nostr auth types ──
/// Response returned by the Nostr challenge endpoint.
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
pub struct NostrChallengeResponse { pub struct NostrChallengeResponse {
pub challenge: String, pub challenge: String,
} }
/// JSON body sent by the client when proving Nostr ownership.
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
pub struct NostrVerifyRequest { pub struct NostrVerifyRequest {
pub signed_event: String, pub signed_event: String,
@ -312,6 +325,7 @@ pub struct NostrVerifyRequest {
pub profile_picture: Option<String>, pub profile_picture: Option<String>,
} }
/// JSON body for inviting an already-known Nostr user into a room.
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
pub struct NostrInviteRequest { pub struct NostrInviteRequest {
pub room_id: String, pub room_id: String,

View File

@ -4,6 +4,7 @@ use crate::services::search::SearchResult;
const BRAVE_SEARCH_URL: &str = "https://api.search.brave.com/res/v1/web/search"; const BRAVE_SEARCH_URL: &str = "https://api.search.brave.com/res/v1/web/search";
/// Partial Brave API response containing only the fields this app needs.
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
struct BraveResponse { struct BraveResponse {
web: Option<BraveWebResults>, web: Option<BraveWebResults>,
@ -27,11 +28,7 @@ struct BraveResult {
/// Search the web using the Brave Search API. /// Search the web using the Brave Search API.
/// Returns a list of simplified search results. /// Returns a list of simplified search results.
pub async fn search( pub async fn search(query: &str, api_key: &str, count: u8) -> Result<Vec<SearchResult>, String> {
query: &str,
api_key: &str,
count: u8,
) -> Result<Vec<SearchResult>, String> {
let count = count.clamp(1, 10); let count = count.clamp(1, 10);
let client = reqwest::Client::new(); let client = reqwest::Client::new();

View File

@ -19,9 +19,29 @@ const STRIP_TAGS: &[&str] = &[
/// Block-level tags that should produce newlines in text output. /// Block-level tags that should produce newlines in text output.
const BLOCK_TAGS: &[&str] = &[ const BLOCK_TAGS: &[&str] = &[
"p", "div", "h1", "h2", "h3", "h4", "h5", "h6", "li", "br", "tr", "p",
"blockquote", "pre", "section", "article", "main", "header", "div",
"dt", "dd", "figcaption", "table", "thead", "tbody", "h1",
"h2",
"h3",
"h4",
"h5",
"h6",
"li",
"br",
"tr",
"blockquote",
"pre",
"section",
"article",
"main",
"header",
"dt",
"dd",
"figcaption",
"table",
"thead",
"tbody",
]; ];
/// Fetch a URL and extract its text content. /// Fetch a URL and extract its text content.

View File

@ -1,3 +1,9 @@
//! Integrations with external systems used by the chat server.
//!
//! These modules wrap search providers, web page fetching, and the OpenRouter
//! chat completion API so the rest of the application can call them with simple
//! Rust types.
pub mod brave; pub mod brave;
pub mod fetch; pub mod fetch;
pub mod openrouter; pub mod openrouter;

View File

@ -235,7 +235,9 @@ pub async fn chat_completion_stream(
{ {
Ok(r) => r, Ok(r) => r,
Err(e) => { Err(e) => {
let _ = tx.send(StreamEvent::Error(format!("Request failed: {}", e))).await; let _ = tx
.send(StreamEvent::Error(format!("Request failed: {}", e)))
.await;
return; return;
} }
}; };
@ -243,7 +245,12 @@ pub async fn chat_completion_stream(
if !response.status().is_success() { if !response.status().is_success() {
let status = response.status(); let status = response.status();
let body = response.text().await.unwrap_or_default(); let body = response.text().await.unwrap_or_default();
let _ = tx.send(StreamEvent::Error(format!("OpenRouter error {}: {}", status, body))).await; let _ = tx
.send(StreamEvent::Error(format!(
"OpenRouter error {}: {}",
status, body
)))
.await;
return; return;
} }
@ -264,7 +271,9 @@ pub async fn chat_completion_stream(
let bytes = match chunk_result { let bytes = match chunk_result {
Ok(b) => b, Ok(b) => b,
Err(e) => { Err(e) => {
let _ = tx.send(StreamEvent::Error(format!("Stream error: {}", e))).await; let _ = tx
.send(StreamEvent::Error(format!("Stream error: {}", e)))
.await;
return; return;
} }
}; };
@ -338,7 +347,10 @@ pub async fn chat_completion_stream(
tool_call_accum[idx].function.name.push_str(name); tool_call_accum[idx].function.name.push_str(name);
} }
if let Some(args) = &func.arguments { if let Some(args) = &func.arguments {
tool_call_accum[idx].function.arguments.push_str(args); tool_call_accum[idx]
.function
.arguments
.push_str(args);
} }
} }
} }
@ -373,7 +385,11 @@ pub async fn chat_completion_stream(
// AI requested tool calls // AI requested tool calls
let assistant_msg = ChatMessage { let assistant_msg = ChatMessage {
role: "assistant".into(), role: "assistant".into(),
content: if full_content.is_empty() { None } else { Some(Content::Text(full_content)) }, content: if full_content.is_empty() {
None
} else {
Some(Content::Text(full_content))
},
tool_calls: Some(tool_call_accum), tool_calls: Some(tool_call_accum),
tool_call_id: None, tool_call_id: None,
}; };
@ -420,7 +436,9 @@ pub fn build_chat_history(
Content::Parts(vec![ Content::Parts(vec![
ContentPart::Text { text }, ContentPart::Text { text },
ContentPart::ImageUrl { ContentPart::ImageUrl {
image_url: ImageUrlData { url: data_url.clone() }, image_url: ImageUrlData {
url: data_url.clone(),
},
}, },
]) ])
} else { } else {

View File

@ -2,6 +2,7 @@ use serde::{Deserialize, Serialize};
use super::{brave, tavily}; use super::{brave, tavily};
/// Which search backend the AI tool layer should call.
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SearchProvider { pub enum SearchProvider {
Tavily, Tavily,
@ -9,8 +10,14 @@ pub enum SearchProvider {
} }
impl SearchProvider { impl SearchProvider {
/// Parse the `SEARCH_PROVIDER` environment variable into a supported variant.
pub fn from_env(value: Option<&str>) -> Result<Self, String> { pub fn from_env(value: Option<&str>) -> Result<Self, String> {
match value.unwrap_or("tavily").trim().to_ascii_lowercase().as_str() { match value
.unwrap_or("tavily")
.trim()
.to_ascii_lowercase()
.as_str()
{
"tavily" => Ok(Self::Tavily), "tavily" => Ok(Self::Tavily),
"brave" => Ok(Self::Brave), "brave" => Ok(Self::Brave),
other => Err(format!( other => Err(format!(
@ -20,6 +27,7 @@ impl SearchProvider {
} }
} }
/// Return the environment variable name required by the selected provider.
pub fn required_key_name(self) -> &'static str { pub fn required_key_name(self) -> &'static str {
match self { match self {
Self::Tavily => "TAVILY_API_KEY", Self::Tavily => "TAVILY_API_KEY",
@ -28,6 +36,7 @@ impl SearchProvider {
} }
} }
/// Normalized search result shape shared across providers.
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult { pub struct SearchResult {
pub title: String, pub title: String,
@ -36,6 +45,7 @@ pub struct SearchResult {
pub age: Option<String>, pub age: Option<String>,
} }
/// Dispatch a search request to whichever provider the server is configured to use.
pub async fn search( pub async fn search(
provider: SearchProvider, provider: SearchProvider,
query: &str, query: &str,
@ -59,6 +69,7 @@ pub async fn search(
} }
} }
/// Turn search results into plain text the AI model can read as tool output.
pub fn format_results(results: &[SearchResult]) -> String { pub fn format_results(results: &[SearchResult]) -> String {
if results.is_empty() { if results.is_empty() {
return "No search results found.".to_string(); return "No search results found.".to_string();

View File

@ -23,11 +23,7 @@ struct TavilyResult {
published_date: Option<String>, published_date: Option<String>,
} }
pub async fn search( pub async fn search(query: &str, api_key: &str, count: u8) -> Result<Vec<SearchResult>, String> {
query: &str,
api_key: &str,
count: u8,
) -> Result<Vec<SearchResult>, String> {
let max_results = count.clamp(1, 10); let max_results = count.clamp(1, 10);
let client = reqwest::Client::new(); let client = reqwest::Client::new();
@ -75,7 +71,10 @@ pub async fn search(
if first.description.is_empty() { if first.description.is_empty() {
first.description = format!("AI summary: {}", answer); first.description = format!("AI summary: {}", answer);
} else { } else {
first.description = format!("AI summary: {}\nSource excerpt: {}", answer, first.description); first.description = format!(
"AI summary: {}\nSource excerpt: {}",
answer, first.description
);
} }
} else { } else {
results.push(SearchResult { results.push(SearchResult {