diff --git a/server/src/handlers/auth.rs b/server/src/handlers/auth.rs index 3189440..54f1cd0 100644 --- a/server/src/handlers/auth.rs +++ b/server/src/handlers/auth.rs @@ -1,8 +1,8 @@ -use axum::{extract::State, http::StatusCode, Json}; use argon2::{ password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString}, Argon2, }; +use axum::{extract::State, http::StatusCode, Json}; use std::sync::Arc; use uuid::Uuid; @@ -12,6 +12,7 @@ use crate::{ AppState, }; +/// Create a new password-based account and immediately return a JWT. pub async fn register( State(state): State>, Json(body): Json, @@ -69,6 +70,7 @@ pub async fn register( })) } +/// Authenticate an existing password-based account and return a fresh JWT. pub async fn login( State(state): State>, Json(body): Json, @@ -87,8 +89,8 @@ pub async fn login( let (user_id, email, display_name, hash, avatar_url) = user; - let parsed_hash = PasswordHash::new(&hash) - .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; + let parsed_hash = + PasswordHash::new(&hash).map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; Argon2::default() .verify_password(body.password.as_bytes(), &parsed_hash) @@ -109,6 +111,7 @@ pub async fn login( })) } +/// Return the caller's current public profile information. pub async fn me( auth: AuthUser, State(state): State>, diff --git a/server/src/handlers/invites.rs b/server/src/handlers/invites.rs index 157e883..c1bf65f 100644 --- a/server/src/handlers/invites.rs +++ b/server/src/handlers/invites.rs @@ -13,6 +13,7 @@ use crate::{ AppState, }; +/// Response payload for a newly created invite link. #[derive(serde::Serialize)] pub struct InviteResponse { pub id: String, @@ -20,6 +21,7 @@ pub struct InviteResponse { pub invite_url: String, } +/// Create a one-time invite token for a room member to share. pub async fn create_invite( State(state): State>, auth: AuthUser, @@ -46,15 +48,17 @@ pub async fn create_invite( .map(char::from) .collect(); - sqlx::query("INSERT INTO invites (id, room_id, invited_by, email, token) VALUES (?, ?, ?, ?, ?)") - .bind(&invite_id) - .bind(&body.room_id) - .bind(&auth.user_id) - .bind(&body.email) - .bind(&token) - .execute(&state.db) - .await - .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; + sqlx::query( + "INSERT INTO invites (id, room_id, invited_by, email, token) VALUES (?, ?, ?, ?, ?)", + ) + .bind(&invite_id) + .bind(&body.room_id) + .bind(&auth.user_id) + .bind(&body.email) + .bind(&token) + .execute(&state.db) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; Ok(Json(InviteResponse { id: invite_id, @@ -63,11 +67,13 @@ pub async fn create_invite( })) } +/// Response payload returned after consuming an invite. #[derive(serde::Serialize)] pub struct AcceptInviteResponse { pub room_id: String, } +/// Consume an invite token and add the caller to the room. pub async fn accept_invite( State(state): State>, auth: AuthUser, @@ -89,13 +95,12 @@ pub async fn accept_invite( } // Verify room is not deleted - let room_active = sqlx::query_scalar::<_, String>( - "SELECT id FROM rooms WHERE id = ? AND deleted_at IS NULL", - ) - .bind(&room_id) - .fetch_optional(&state.db) - .await - .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; + let room_active = + sqlx::query_scalar::<_, String>("SELECT id FROM rooms WHERE id = ? AND deleted_at IS NULL") + .bind(&room_id) + .fetch_optional(&state.db) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; if room_active.is_none() { return Err((StatusCode::GONE, "This room has been deleted".into())); @@ -119,6 +124,7 @@ pub async fn accept_invite( Ok(Json(AcceptInviteResponse { room_id })) } +/// Result of a Nostr-based room invite attempt. #[derive(serde::Serialize)] pub struct NostrInviteResponse { pub status: String, @@ -126,6 +132,7 @@ pub struct NostrInviteResponse { pub display_name: Option, } +/// Add a user to a room by their Nostr public key if they already have an account. pub async fn invite_by_nostr( State(state): State>, auth: AuthUser, @@ -138,8 +145,13 @@ pub async fn invite_by_nostr( .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid npub format".to_string()))? } else { // Validate it's valid hex - if body.nostr_pubkey.len() != 64 || !body.nostr_pubkey.chars().all(|c| c.is_ascii_hexdigit()) { - return Err((StatusCode::BAD_REQUEST, "Invalid pubkey: must be 64-char hex or npub".to_string())); + if body.nostr_pubkey.len() != 64 + || !body.nostr_pubkey.chars().all(|c| c.is_ascii_hexdigit()) + { + return Err(( + StatusCode::BAD_REQUEST, + "Invalid pubkey: must be 64-char hex or npub".to_string(), + )); } body.nostr_pubkey.clone() }; diff --git a/server/src/handlers/mod.rs b/server/src/handlers/mod.rs index 2861e62..cb34c9b 100644 --- a/server/src/handlers/mod.rs +++ b/server/src/handlers/mod.rs @@ -1,3 +1,8 @@ +//! HTTP and WebSocket entry points for the server. +//! +//! Each submodule exposes route handlers that Axum wires into the router in +//! `main.rs`. + pub mod auth; pub mod invites; pub mod models; diff --git a/server/src/handlers/models.rs b/server/src/handlers/models.rs index 1b63a20..b48b6f0 100644 --- a/server/src/handlers/models.rs +++ b/server/src/handlers/models.rs @@ -1,19 +1,16 @@ -use axum::{ - extract::State, - http::StatusCode, - Json, -}; +use axum::{extract::State, http::StatusCode, Json}; use serde::{Deserialize, Serialize}; use std::sync::Arc; -use tokio::sync::OnceCell; use std::time::{Duration, Instant}; use tokio::sync::Mutex; +use tokio::sync::OnceCell; use crate::AppState; /// Cached model list with expiry. static MODEL_CACHE: OnceCell> = OnceCell::const_new(); +/// Process-wide cache for the OpenRouter model catalog. struct CachedModels { models: Vec, fetched_at: Instant, @@ -21,6 +18,7 @@ struct CachedModels { const CACHE_TTL: Duration = Duration::from_secs(60 * 30); // 30 minutes +/// Model metadata exposed to the client for room creation and model selection. #[derive(Debug, Clone, Serialize)] pub struct ModelInfo { pub id: String, @@ -56,6 +54,10 @@ struct OpenRouterArchitecture { input_modalities: Option>, } +/// Fetch the model catalog directly from OpenRouter. +/// +/// The result is normalized into the smaller `ModelInfo` shape that the client +/// UI needs. async fn fetch_models(api_key: &str) -> Result, String> { let client = reqwest::Client::new(); @@ -82,7 +84,8 @@ async fn fetch_models(api_key: &str) -> Result, String> { .into_iter() .map(|m| { let pricing = m.pricing.as_ref(); - let supports_vision = m.architecture + let supports_vision = m + .architecture .as_ref() .and_then(|a| a.input_modalities.as_ref()) .map(|mods| mods.iter().any(|m| m == "image")) @@ -102,6 +105,7 @@ async fn fetch_models(api_key: &str) -> Result, String> { Ok(models) } +/// Return the cached OpenRouter model list, refreshing it when the cache expires. pub async fn list_models( State(state): State>, ) -> Result>, (StatusCode, String)> { diff --git a/server/src/handlers/nostr_auth.rs b/server/src/handlers/nostr_auth.rs index dd862f7..1bacdc6 100644 --- a/server/src/handlers/nostr_auth.rs +++ b/server/src/handlers/nostr_auth.rs @@ -10,6 +10,7 @@ use crate::{ AppState, }; +/// Claims embedded in the short-lived challenge token used during Nostr login. #[derive(Debug, serde::Serialize, serde::Deserialize)] struct ChallengeClaims { pub nonce: String, @@ -28,10 +29,7 @@ pub async fn challenge( let exp = (chrono::Utc::now().timestamp() + 120) as usize; // 2 minutes - let claims = ChallengeClaims { - nonce, - exp, - }; + let claims = ChallengeClaims { nonce, exp }; let token = encode( &Header::default(), @@ -45,6 +43,7 @@ pub async fn challenge( /// Simple hex encoder (avoid adding the `hex` crate just for this) mod hex { + /// Convert raw bytes into a lowercase hexadecimal string. pub fn encode(bytes: &[u8]) -> String { bytes.iter().map(|b| format!("{:02x}", b)).collect() } @@ -61,17 +60,29 @@ pub async fn verify( &DecodingKey::from_secret(state.jwt_secret.as_bytes()), &Validation::default(), ) - .map_err(|_| (StatusCode::BAD_REQUEST, "Invalid or expired challenge".to_string()))?; + .map_err(|_| { + ( + StatusCode::BAD_REQUEST, + "Invalid or expired challenge".to_string(), + ) + })?; let nonce = &challenge_data.claims.nonce; // 2. Deserialize signed_event as nostr::Event - let event: Event = serde_json::from_str(&body.signed_event) - .map_err(|e| (StatusCode::BAD_REQUEST, format!("Invalid event JSON: {}", e)))?; + let event: Event = serde_json::from_str(&body.signed_event).map_err(|e| { + ( + StatusCode::BAD_REQUEST, + format!("Invalid event JSON: {}", e), + ) + })?; // 3. Verify Schnorr signature if !event.verify_signature() { - return Err((StatusCode::UNAUTHORIZED, "Invalid event signature".to_string())); + return Err(( + StatusCode::UNAUTHORIZED, + "Invalid event signature".to_string(), + )); } // 4. Verify event.content == nonce @@ -83,7 +94,10 @@ pub async fn verify( let now = chrono::Utc::now().timestamp() as u64; let event_ts = event.created_at.as_secs(); if now.abs_diff(event_ts) > 300 { - return Err((StatusCode::BAD_REQUEST, "Event timestamp too far off".to_string())); + return Err(( + StatusCode::BAD_REQUEST, + "Event timestamp too far off".to_string(), + )); } // 6. Extract pubkey hex diff --git a/server/src/handlers/profile.rs b/server/src/handlers/profile.rs index d8b6c4e..8a55bb8 100644 --- a/server/src/handlers/profile.rs +++ b/server/src/handlers/profile.rs @@ -11,6 +11,7 @@ use crate::{ AppState, }; +/// Request body for profile updates. #[derive(Debug, serde::Deserialize)] pub struct UpdateProfileRequest { pub display_name: Option, @@ -25,7 +26,10 @@ pub async fn update_profile( let display_name = body.display_name.unwrap_or(auth.display_name.clone()); if display_name.trim().is_empty() { - return Err((StatusCode::BAD_REQUEST, "Display name cannot be empty".into())); + return Err(( + StatusCode::BAD_REQUEST, + "Display name cannot be empty".into(), + )); } sqlx::query("UPDATE users SET display_name = ? WHERE id = ?") @@ -83,7 +87,12 @@ pub async fn upload_avatar( "image/jpeg" | "image/jpg" => "jpg", "image/gif" => "gif", "image/webp" => "webp", - _ => return Err((StatusCode::BAD_REQUEST, "Only PNG, JPG, GIF, and WebP images are allowed".into())), + _ => { + return Err(( + StatusCode::BAD_REQUEST, + "Only PNG, JPG, GIF, and WebP images are allowed".into(), + )) + } }; let data = field @@ -130,8 +139,13 @@ pub async fn upload_avatar( .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; // Issue new token - let token = create_token(&auth.user_id, &auth.email, &auth.display_name, &state.jwt_secret) - .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; + let token = create_token( + &auth.user_id, + &auth.email, + &auth.display_name, + &state.jwt_secret, + ) + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; Ok(Json(AuthResponse { token, @@ -168,8 +182,13 @@ pub async fn delete_avatar( .await .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; - let token = create_token(&auth.user_id, &auth.email, &auth.display_name, &state.jwt_secret) - .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; + let token = create_token( + &auth.user_id, + &auth.email, + &auth.display_name, + &state.jwt_secret, + ) + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; Ok(Json(AuthResponse { token, diff --git a/server/src/handlers/rooms.rs b/server/src/handlers/rooms.rs index 0930890..5150566 100644 --- a/server/src/handlers/rooms.rs +++ b/server/src/handlers/rooms.rs @@ -8,10 +8,13 @@ use uuid::Uuid; use crate::{ middleware::auth::AuthUser, - models::{self, CreateRoomRequest, MessagePayload, PaginationParams, Room, RoomResponse, UserPublic}, + models::{ + self, CreateRoomRequest, MessagePayload, PaginationParams, Room, RoomResponse, UserPublic, + }, AppState, }; +/// Create a room, persist it, and add the creator as the first member. pub async fn create_room( State(state): State>, auth: AuthUser, @@ -60,6 +63,7 @@ pub async fn create_room( })) } +/// List all active rooms the caller belongs to, including current room members. pub async fn list_rooms( State(state): State>, auth: AuthUser, @@ -93,13 +97,15 @@ pub async fn list_rooms( created_at: room.created_at, members: members .into_iter() - .map(|(id, email, display_name, avatar_url, nostr_pubkey)| UserPublic { - id, - email: models::public_email(&email), - display_name, - avatar_url, - nostr_pubkey, - }) + .map( + |(id, email, display_name, avatar_url, nostr_pubkey)| UserPublic { + id, + email: models::public_email(&email), + display_name, + avatar_url, + nostr_pubkey, + }, + ) .collect(), }); } @@ -107,6 +113,7 @@ pub async fn list_rooms( Ok(Json(result)) } +/// Return details for a single room after verifying the caller is a member. pub async fn get_room( State(state): State>, auth: AuthUser, @@ -152,17 +159,20 @@ pub async fn get_room( created_at: room.created_at, members: members .into_iter() - .map(|(id, email, display_name, avatar_url, nostr_pubkey)| UserPublic { - id, - email: models::public_email(&email), - display_name, - avatar_url, - nostr_pubkey, - }) + .map( + |(id, email, display_name, avatar_url, nostr_pubkey)| UserPublic { + id, + email: models::public_email(&email), + display_name, + avatar_url, + nostr_pubkey, + }, + ) .collect(), })) } +/// Return paginated message history for a room the caller can access. pub async fn get_messages( State(state): State>, auth: AuthUser, @@ -208,37 +218,56 @@ pub async fn get_messages( } .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; + // The SQL query reads newest-first for efficient pagination, but clients + // render chat oldest-to-newest, so reverse the rows before serializing. let payloads: Vec = rows .into_iter() .rev() - .map(|(id, room_id, sender_id, sender_name, content, mentions, is_ai, created_at, ai_meta_str, image_url, email, avatar_url, hash)| { - let ai_meta = ai_meta_str - .as_deref() - .and_then(|s| serde_json::from_str::(s).ok()); - let avatar_hash = email - .map(|e| crate::models::gravatar_hash(&e)) - .unwrap_or_default(); - MessagePayload { + .map( + |( id, room_id, sender_id, sender_name, content, - mentions: serde_json::from_str(&mentions).unwrap_or_default(), + mentions, is_ai, created_at, - ai_meta, - avatar_hash, - avatar_url, + ai_meta_str, image_url, + email, + avatar_url, hash, - } - }) + )| { + let ai_meta = ai_meta_str + .as_deref() + .and_then(|s| serde_json::from_str::(s).ok()); + let avatar_hash = email + .map(|e| crate::models::gravatar_hash(&e)) + .unwrap_or_default(); + MessagePayload { + id, + room_id, + sender_id, + sender_name, + content, + mentions: serde_json::from_str(&mentions).unwrap_or_default(), + is_ai, + created_at, + ai_meta, + avatar_hash, + avatar_url, + image_url, + hash, + } + }, + ) .collect(); Ok(Json(payloads)) } +/// Resolve a stable message hash into the room that contains it. pub async fn resolve_message_hash( State(state): State>, auth: AuthUser, @@ -258,22 +287,29 @@ pub async fn resolve_message_hash( .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; match row { - Some((room_id,)) => Ok(Json(serde_json::json!({ "room_id": room_id, "hash": hash }))), - None => Err((StatusCode::NOT_FOUND, "Message not found or no access".into())), + Some((room_id,)) => Ok(Json( + serde_json::json!({ "room_id": room_id, "hash": hash }), + )), + None => Err(( + StatusCode::NOT_FOUND, + "Message not found or no access".into(), + )), } } +/// Add the caller to a room directly when they already know its ID. pub async fn join_room( State(state): State>, auth: AuthUser, Path(room_id): Path, ) -> Result { // Check room exists - let room_exists = sqlx::query_scalar::<_, String>("SELECT id FROM rooms WHERE id = ? AND deleted_at IS NULL") - .bind(&room_id) - .fetch_optional(&state.db) - .await - .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; + let room_exists = + sqlx::query_scalar::<_, String>("SELECT id FROM rooms WHERE id = ? AND deleted_at IS NULL") + .bind(&room_id) + .fetch_optional(&state.db) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; if room_exists.is_none() { return Err((StatusCode::NOT_FOUND, "Room not found".into())); @@ -289,6 +325,7 @@ pub async fn join_room( Ok(StatusCode::OK) } +/// Soft-delete a room and broadcast the deletion event to connected members. pub async fn delete_room( State(state): State>, auth: AuthUser, @@ -303,7 +340,10 @@ pub async fn delete_room( .ok_or((StatusCode::NOT_FOUND, "Room not found".into()))?; if room.created_by != auth.user_id { - return Err((StatusCode::FORBIDDEN, "Only the room creator can delete this room".into())); + return Err(( + StatusCode::FORBIDDEN, + "Only the room creator can delete this room".into(), + )); } // Soft-delete @@ -324,6 +364,7 @@ pub async fn delete_room( Ok(StatusCode::OK) } +/// Permanently remove all messages from a room without deleting the room itself. pub async fn clear_room( State(state): State>, auth: AuthUser, @@ -338,7 +379,10 @@ pub async fn clear_room( .ok_or((StatusCode::NOT_FOUND, "Room not found".into()))?; if room.created_by != auth.user_id { - return Err((StatusCode::FORBIDDEN, "Only the room creator can clear messages".into())); + return Err(( + StatusCode::FORBIDDEN, + "Only the room creator can clear messages".into(), + )); } // Hard-delete all messages diff --git a/server/src/handlers/upload.rs b/server/src/handlers/upload.rs index 107a026..90f5900 100644 --- a/server/src/handlers/upload.rs +++ b/server/src/handlers/upload.rs @@ -1,18 +1,16 @@ -use axum::{ - extract::Multipart, - http::StatusCode, - Json, -}; +use axum::{extract::Multipart, http::StatusCode, Json}; use serde::Serialize; use uuid::Uuid; use crate::middleware::auth::AuthUser; +/// Response returned after a chat image upload succeeds. #[derive(Serialize)] pub struct UploadResponse { pub url: String, } +/// Accept a multipart chat image upload and store it under `uploads/chat-images`. pub async fn upload_chat_image( _auth: AuthUser, mut multipart: Multipart, diff --git a/server/src/handlers/ws.rs b/server/src/handlers/ws.rs index 2790b11..6221c14 100644 --- a/server/src/handlers/ws.rs +++ b/server/src/handlers/ws.rs @@ -1,3 +1,9 @@ +//! WebSocket workflow for live chat delivery and AI responses. +//! +//! This module does two jobs: +//! - fan out database-backed room events to subscribed browser sockets +//! - turn incoming user chat messages into stored messages and optional AI replies + use axum::{ extract::{ ws::{Message, WebSocket}, @@ -24,6 +30,7 @@ pub struct WsQuery { token: String, } +/// Upgrade an authenticated request into a WebSocket connection. pub async fn ws_handler( ws: WebSocketUpgrade, State(state): State>, @@ -37,10 +44,19 @@ pub async fn ws_handler( } }; - ws.on_upgrade(move |socket| handle_socket(socket, state, claims.sub, claims.display_name, claims.email)) + ws.on_upgrade(move |socket| { + handle_socket(socket, state, claims.sub, claims.display_name, claims.email) + }) } -async fn handle_socket(socket: WebSocket, state: Arc, user_id: String, display_name: String, email: String) { +/// Drive a single WebSocket connection until either the send or receive side ends. +async fn handle_socket( + socket: WebSocket, + state: Arc, + user_id: String, + display_name: String, + email: String, +) { let (mut ws_tx, mut ws_rx) = socket.split(); let mut broadcast_rx = state.tx.subscribe(); @@ -50,7 +66,8 @@ async fn handle_socket(socket: WebSocket, state: Arc, user_id: String, let rooms_clone = subscribed_rooms.clone(); - // Task: forward broadcast events to this client + // Task 1: forward room events from the shared broadcast channel into this + // specific socket, but only for rooms the browser subscribed to. let mut send_task = tokio::spawn(async move { loop { match broadcast_rx.recv().await { @@ -81,7 +98,8 @@ async fn handle_socket(socket: WebSocket, state: Arc, user_id: String, let email_clone = email.clone(); let rooms_clone2 = subscribed_rooms.clone(); - // Task: receive messages from client + // Task 2: receive commands from the browser and translate them into + // database writes, broadcasts, or AI work. let mut recv_task = tokio::spawn(async move { while let Some(Ok(msg)) = ws_rx.next().await { let text = match msg { @@ -141,7 +159,7 @@ async fn handle_socket(socket: WebSocket, state: Arc, user_id: String, } }); - // Wait for either task to finish, then abort the other + // If either half of the connection ends, stop the companion task too. tokio::select! { _ = &mut send_task => recv_task.abort(), _ = &mut recv_task => send_task.abort(), @@ -150,6 +168,7 @@ async fn handle_socket(socket: WebSocket, state: Arc, user_id: String, tracing::info!("WebSocket disconnected: {}", user_id); } +/// Persist a user message, broadcast it, and optionally generate an AI reply. async fn handle_send_message( state: &Arc, user_id: &str, @@ -184,13 +203,14 @@ async fn handle_send_message( .await; // Look up the sender's custom avatar (if any) for the message payload - let avatar_url: Option = sqlx::query_scalar("SELECT avatar_url FROM users WHERE id = ?") - .bind(user_id) - .fetch_optional(&state.db) - .await - .ok() - .flatten() - .flatten(); + let avatar_url: Option = + sqlx::query_scalar("SELECT avatar_url FROM users WHERE id = ?") + .bind(user_id) + .fetch_optional(&state.db) + .await + .ok() + .flatten() + .flatten(); // Broadcast human message let payload = MessagePayload { @@ -211,12 +231,11 @@ async fn handle_send_message( let _ = state.tx.send(BroadcastEvent { room_id: room_id.to_string(), - message: WsServerMessage::NewMessage { - message: payload, - }, + message: WsServerMessage::NewMessage { message: payload }, }); - // Check if AI should respond + // The AI only replies when explicitly mentioned or when the room is set to + // auto-reply to every message. let ai_user_id = "ai-assistant"; let should_respond = mentions.contains(&ai_user_id.to_string()); @@ -254,7 +273,8 @@ async fn handle_send_message( .await .unwrap_or_default(); - // Process history: encode images as base64 data URLs for OpenRouter + // OpenRouter accepts image inputs as data URLs, so local uploads need to be + // loaded from disk and encoded before they are sent upstream. let mut history: Vec<(String, String, bool, Option)> = Vec::new(); for (sender_name, msg_content, is_ai, msg_image_url) in recent_messages.into_iter().rev() { let image_data_url = match &msg_image_url { @@ -272,7 +292,8 @@ async fn handle_send_message( // Pre-generate AI message ID so we can reference it in stream chunks let ai_msg_id = Uuid::new_v4().to_string(); - // Call OpenRouter with tool loop — uses streaming for all rounds + // Run the AI in a loop because the model may first request tools, then need + // follow-up rounds after those tool results are added to history. let mut total_prompt_tokens: u32 = 0; let mut total_completion_tokens: u32 = 0; let mut total_response_ms: u64 = 0; @@ -313,16 +334,24 @@ async fn handle_send_message( tracing::info!( "AI requesting tool calls (round {}): {:?}", round + 1, - assistant_msg.tool_calls.as_ref().map(|tc| tc.iter().map(|t| &t.function.name).collect::>()) + assistant_msg + .tool_calls + .as_ref() + .map(|tc| tc.iter().map(|t| &t.function.name).collect::>()) ); - // Add the assistant's tool-call message to history + // Preserve the assistant tool-call message so the next round + // has the same context the model produced. let tool_calls = assistant_msg.tool_calls.clone().unwrap_or_default(); chat_history.push(assistant_msg); - // Execute each tool call and add results + // Tool results are fed back into the conversation as + // synthetic `tool` messages, matching the upstream API. for tool_call in &tool_calls { - let tool_input = extract_tool_input(&tool_call.function.name, &tool_call.function.arguments); + let tool_input = extract_tool_input( + &tool_call.function.name, + &tool_call.function.arguments, + ); // Broadcast real-time tool usage event let _ = state.tx.send(BroadcastEvent { @@ -362,7 +391,7 @@ async fn handle_send_message( tool_call_id: Some(tool_call.id.clone()), }); } - // Continue to next round (tool loop) + // Ask the model to continue now that tool output exists. continue 'tool_loop; } openrouter::StreamEvent::Done(stats) => { @@ -382,9 +411,12 @@ async fn handle_send_message( } } - // If we exhausted all rounds without a text response, note it + // Guardrail: if the model never produced final prose, store a clear fallback + // instead of leaving the client waiting indefinitely. if ai_response.is_empty() && !had_error { - ai_response = "*I used several tools but couldn't formulate a final response. Please try again.*".to_string(); + ai_response = + "*I used several tools but couldn't formulate a final response. Please try again.*" + .to_string(); } // Signal stream end so client can finalize rendering @@ -512,7 +544,15 @@ async fn execute_tool( return "Error: search query is required".into(); } - match search::search(search_provider, &query, tavily_api_key, brave_api_key, count).await { + match search::search( + search_provider, + &query, + tavily_api_key, + brave_api_key, + count, + ) + .await + { Ok(results) => search::format_results(&results), Err(e) => format!("Search error: {}", e), } diff --git a/server/src/main.rs b/server/src/main.rs index 549cf73..06881dc 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,3 +1,12 @@ +//! Application bootstrap for the GroupChat server. +//! +//! This file is responsible for: +//! - loading environment configuration +//! - opening and migrating the SQLite database +//! - constructing shared application state +//! - registering HTTP/WebSocket routes +//! - serving the SPA frontend in production + mod handlers; mod middleware; mod models; @@ -51,14 +60,8 @@ fn backup_database(database_url: &str) { } // Build timestamped backup filename: chat.db -> chat_2026-03-09_143022.db - let stem = db_file - .file_stem() - .and_then(|s| s.to_str()) - .unwrap_or("db"); - let ext = db_file - .extension() - .and_then(|s| s.to_str()) - .unwrap_or("db"); + let stem = db_file.file_stem().and_then(|s| s.to_str()).unwrap_or("db"); + let ext = db_file.extension().and_then(|s| s.to_str()).unwrap_or("db"); let now = chrono::Local::now(); let backup_name = format!("{}_{}.{}", stem, now.format("%Y-%m-%d_%H%M%S"), ext); @@ -82,11 +85,21 @@ fn backup_database(database_url: &str) { let wal_path = format!("{}-wal", db_path); let shm_path = format!("{}-shm", db_path); if std::path::Path::new(&wal_path).exists() { - let wal_backup = backup_dir.join(format!("{}_{}.{}-wal", stem, now.format("%Y-%m-%d_%H%M%S"), ext)); + let wal_backup = backup_dir.join(format!( + "{}_{}.{}-wal", + stem, + now.format("%Y-%m-%d_%H%M%S"), + ext + )); let _ = std::fs::copy(&wal_path, &wal_backup); } if std::path::Path::new(&shm_path).exists() { - let shm_backup = backup_dir.join(format!("{}_{}.{}-shm", stem, now.format("%Y-%m-%d_%H%M%S"), ext)); + let shm_backup = backup_dir.join(format!( + "{}_{}.{}-shm", + stem, + now.format("%Y-%m-%d_%H%M%S"), + ext + )); let _ = std::fs::copy(&shm_path, &shm_backup); } @@ -119,20 +132,35 @@ fn prune_old_backups(backup_dir: &std::path::Path, stem: &str, keep: usize) { let to_remove = backups.len() - keep; for entry in backups.into_iter().take(to_remove) { let path = entry.path(); - let name = path.file_name().unwrap_or_default().to_string_lossy().to_string(); + let name = path + .file_name() + .unwrap_or_default() + .to_string_lossy() + .to_string(); if let Err(e) = std::fs::remove_file(&path) { tracing::warn!("Failed to remove old backup {}: {}", name, e); } else { tracing::debug!("Pruned old backup: {}", name); // Also remove associated WAL/SHM backups - let wal = path.with_extension(format!("{}-wal", path.extension().unwrap_or_default().to_string_lossy())); - let shm = path.with_extension(format!("{}-shm", path.extension().unwrap_or_default().to_string_lossy())); + let wal = path.with_extension(format!( + "{}-wal", + path.extension().unwrap_or_default().to_string_lossy() + )); + let shm = path.with_extension(format!( + "{}-shm", + path.extension().unwrap_or_default().to_string_lossy() + )); let _ = std::fs::remove_file(&wal); let _ = std::fs::remove_file(&shm); } } } +/// Shared state injected into every handler. +/// +/// Axum stores this behind an `Arc`, so handlers can cheaply clone the pointer +/// while all requests still talk to the same database pool, API keys, and +/// broadcast channel. pub struct AppState { pub db: sqlx::SqlitePool, pub jwt_secret: String, @@ -154,11 +182,15 @@ async fn main() { .with(tracing_subscriber::fmt::layer()) .init(); - let database_url = std::env::var("DATABASE_URL").unwrap_or_else(|_| "sqlite:chat.db?mode=rwc".into()); + // Load the runtime configuration needed to start the server. + let database_url = + std::env::var("DATABASE_URL").unwrap_or_else(|_| "sqlite:chat.db?mode=rwc".into()); let jwt_secret = std::env::var("JWT_SECRET").unwrap_or_else(|_| "dev-secret-change-me".into()); - let openrouter_key = std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY must be set"); - let search_provider = SearchProvider::from_env(std::env::var("SEARCH_PROVIDER").ok().as_deref()) - .unwrap_or_else(|e| panic!("{}", e)); + let openrouter_key = + std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY must be set"); + let search_provider = + SearchProvider::from_env(std::env::var("SEARCH_PROVIDER").ok().as_deref()) + .unwrap_or_else(|e| panic!("{}", e)); let tavily_api_key = std::env::var("TAVILY_API_KEY").ok(); let brave_api_key = std::env::var("BRAVE_API_KEY").ok(); @@ -181,7 +213,8 @@ async fn main() { .await .expect("Failed to connect to database"); - // Run migrations + // Run migrations in order. Each one is written so startup can safely try it + // again and skip work that already happened in an earlier run. let migration_sql = include_str!("../migrations/001_init.sql"); sqlx::raw_sql(migration_sql) .execute(&db) @@ -282,6 +315,8 @@ async fn main() { tracing::info!("Database initialized"); + // WebSocket tasks subscribe to this channel to receive room events without + // polling the database. let (tx, _rx) = broadcast::channel::(4096); let state = Arc::new(AppState { @@ -302,32 +337,61 @@ async fn main() { // Serve static files from client dist in production let static_dir = std::env::var("STATIC_DIR").unwrap_or_else(|_| "../client/dist".into()); + // Keep API routes separate from the static-file fallback so `/api/*` and + // `/ws` requests never get mistaken for SPA routes. let api_routes = Router::new() // Auth routes .route("/api/auth/register", post(handlers::auth::register)) .route("/api/auth/login", post(handlers::auth::login)) .route("/api/auth/me", get(handlers::auth::me)) // Nostr auth routes - .route("/api/auth/nostr/challenge", get(handlers::nostr_auth::challenge)) + .route( + "/api/auth/nostr/challenge", + get(handlers::nostr_auth::challenge), + ) .route("/api/auth/nostr/verify", post(handlers::nostr_auth::verify)) // Profile routes .route("/api/auth/profile", put(handlers::profile::update_profile)) - .route("/api/auth/avatar", post(handlers::profile::upload_avatar).delete(handlers::profile::delete_avatar)) + .route( + "/api/auth/avatar", + post(handlers::profile::upload_avatar).delete(handlers::profile::delete_avatar), + ) // Room routes - .route("/api/rooms", get(handlers::rooms::list_rooms).post(handlers::rooms::create_room)) - .route("/api/rooms/:room_id", get(handlers::rooms::get_room).delete(handlers::rooms::delete_room)) - .route("/api/rooms/:room_id/messages", get(handlers::rooms::get_messages)) + .route( + "/api/rooms", + get(handlers::rooms::list_rooms).post(handlers::rooms::create_room), + ) + .route( + "/api/rooms/:room_id", + get(handlers::rooms::get_room).delete(handlers::rooms::delete_room), + ) + .route( + "/api/rooms/:room_id/messages", + get(handlers::rooms::get_messages), + ) .route("/api/rooms/:room_id/join", post(handlers::rooms::join_room)) - .route("/api/rooms/:room_id/clear", post(handlers::rooms::clear_room)) - .route("/api/messages/hash/:hash", get(handlers::rooms::resolve_message_hash)) + .route( + "/api/rooms/:room_id/clear", + post(handlers::rooms::clear_room), + ) + .route( + "/api/messages/hash/:hash", + get(handlers::rooms::resolve_message_hash), + ) // Upload (chat images) .route("/api/upload", post(handlers::upload::upload_chat_image)) // Models .route("/api/models", get(handlers::models::list_models)) // Invite routes .route("/api/invites", post(handlers::invites::create_invite)) - .route("/api/invites/:token/accept", post(handlers::invites::accept_invite)) - .route("/api/invites/nostr", post(handlers::invites::invite_by_nostr)) + .route( + "/api/invites/:token/accept", + post(handlers::invites::accept_invite), + ) + .route( + "/api/invites/nostr", + post(handlers::invites::invite_by_nostr), + ) // Uploaded files (avatars) .nest_service("/uploads", ServeDir::new("uploads")) // WebSocket diff --git a/server/src/middleware/auth.rs b/server/src/middleware/auth.rs index 37610e9..311742b 100644 --- a/server/src/middleware/auth.rs +++ b/server/src/middleware/auth.rs @@ -1,14 +1,11 @@ use async_trait::async_trait; -use axum::{ - extract::FromRequestParts, - http::request::Parts, -}; +use axum::{extract::FromRequestParts, http::request::Parts}; use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation}; use std::sync::Arc; use crate::{models::Claims, AppState}; -/// Extract authenticated user from JWT in Authorization header +/// Authenticated user information extracted from the bearer token. pub struct AuthUser { pub user_id: String, pub email: String, @@ -19,7 +16,15 @@ pub struct AuthUser { impl FromRequestParts> for AuthUser { type Rejection = axum::http::StatusCode; - async fn from_request_parts(parts: &mut Parts, state: &Arc) -> Result { + /// Read the `Authorization: Bearer ` header and decode the JWT. + /// + /// Axum runs this automatically for any handler parameter of type + /// `AuthUser`, which keeps individual handlers free from repeated token + /// parsing logic. + async fn from_request_parts( + parts: &mut Parts, + state: &Arc, + ) -> Result { let auth_header = parts .headers .get("Authorization") @@ -41,7 +46,16 @@ impl FromRequestParts> for AuthUser { } } -pub fn create_token(user_id: &str, email: &str, display_name: &str, secret: &str) -> Result { +/// Create a signed JWT for a logged-in user. +/// +/// The token expires after seven days and carries the small amount of identity +/// data the server wants available on every request. +pub fn create_token( + user_id: &str, + email: &str, + display_name: &str, + secret: &str, +) -> Result { let expiration = chrono::Utc::now() .checked_add_signed(chrono::Duration::days(7)) .unwrap() @@ -61,6 +75,7 @@ pub fn create_token(user_id: &str, email: &str, display_name: &str, secret: &str ) } +/// Decode and validate a previously issued JWT. pub fn decode_token(token: &str, secret: &str) -> Result { let token_data = decode::( token, diff --git a/server/src/middleware/mod.rs b/server/src/middleware/mod.rs index 0e4a05d..70723e0 100644 --- a/server/src/middleware/mod.rs +++ b/server/src/middleware/mod.rs @@ -1 +1,3 @@ +//! Reusable request-processing layers shared across handlers. + pub mod auth; diff --git a/server/src/models/mod.rs b/server/src/models/mod.rs index 455a98c..7a272bf 100644 --- a/server/src/models/mod.rs +++ b/server/src/models/mod.rs @@ -1,7 +1,14 @@ +//! Core data structures shared across the server. +//! +//! This file intentionally mixes database row types, HTTP payloads, WebSocket +//! payloads, and a few helper functions so the rest of the codebase can import +//! common shapes from one place. + use serde::{Deserialize, Serialize}; // ── Database models ── +/// Row from the `users` table. #[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] pub struct User { pub id: String, @@ -11,6 +18,7 @@ pub struct User { pub created_at: String, } +/// Row from the `rooms` table. #[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] pub struct Room { pub id: String, @@ -24,6 +32,7 @@ pub struct Room { pub deleted_at: Option, } +/// Row from the `messages` table. #[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] pub struct Message { pub id: String, @@ -38,6 +47,7 @@ pub struct Message { pub hash: Option, } +/// Row from the `invites` table. #[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] pub struct Invite { pub id: String, @@ -51,6 +61,7 @@ pub struct Invite { // ── API request/response types ── +/// JSON body expected by the registration endpoint. #[derive(Debug, Deserialize)] pub struct RegisterRequest { pub email: String, @@ -58,18 +69,21 @@ pub struct RegisterRequest { pub display_name: String, } +/// JSON body expected by the login endpoint. #[derive(Debug, Deserialize)] pub struct LoginRequest { pub email: String, pub password: String, } +/// Standard auth response returned after login, registration, or profile update. #[derive(Debug, Serialize)] pub struct AuthResponse { pub token: String, pub user: UserPublic, } +/// Public user data safe to return to any authenticated client. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct UserPublic { pub id: String, @@ -81,6 +95,7 @@ pub struct UserPublic { pub nostr_pubkey: Option, } +/// JSON body used when a user creates a new chat room. #[derive(Debug, Deserialize)] pub struct CreateRoomRequest { pub name: String, @@ -93,10 +108,10 @@ pub struct CreateRoomRequest { pub ai_name: String, } +/// Pick a friendly default AI display name when the creator does not specify one. fn default_ai_name() -> String { let names = [ - "Nova", "Atlas", "Sage", "Echo", "Pixel", - "Cosmo", "Ember", "Flux", "Lyra", "Onyx", + "Nova", "Atlas", "Sage", "Echo", "Pixel", "Cosmo", "Ember", "Flux", "Lyra", "Onyx", ]; let idx = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) @@ -105,10 +120,12 @@ fn default_ai_name() -> String { names[idx].to_string() } +/// Default prompt that defines the AI assistant's behavior inside a room. fn default_system_prompt() -> String { "You are a helpful AI assistant participating in a group chat. Be conversational, helpful, and concise. You can see messages from all participants. When mentioned with @ai, respond helpfully.\n\nYou have access to tools:\n- **web_search**: Search the web for current information. Use this when asked about recent events, news, facts you're unsure about, or anything that needs up-to-date information.\n- **web_fetch**: Fetch and read the content of a web page. Use this when a user shares a URL and wants you to read/summarize it, or when you need more details from a search result.\n\nUse tools proactively when they would help answer the question better. You don't need to ask permission to use them.".to_string() } +/// Full room payload returned to the client, including current members. #[derive(Debug, Serialize)] pub struct RoomResponse { pub id: String, @@ -122,6 +139,7 @@ pub struct RoomResponse { pub members: Vec, } +/// JSON body for an email-based room invite. #[derive(Debug, Deserialize)] pub struct CreateInviteRequest { pub room_id: String, @@ -130,6 +148,7 @@ pub struct CreateInviteRequest { // ── WebSocket event types ── +/// Messages the browser can send over the WebSocket connection. #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type")] pub enum WsClientMessage { @@ -148,17 +167,14 @@ pub enum WsClientMessage { Typing { room_id: String }, } +/// Messages the server can push to browsers over the WebSocket connection. #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type")] pub enum WsServerMessage { #[serde(rename = "new_message")] - NewMessage { - message: MessagePayload, - }, + NewMessage { message: MessagePayload }, #[serde(rename = "ai_typing")] - AiTyping { - room_id: String, - }, + AiTyping { room_id: String }, #[serde(rename = "user_typing")] UserTyping { room_id: String, @@ -166,21 +182,13 @@ pub enum WsServerMessage { display_name: String, }, #[serde(rename = "error")] - Error { - message: String, - }, + Error { message: String }, #[serde(rename = "joined")] - Joined { - room_id: String, - }, + Joined { room_id: String }, #[serde(rename = "room_deleted")] - RoomDeleted { - room_id: String, - }, + RoomDeleted { room_id: String }, #[serde(rename = "room_cleared")] - RoomCleared { - room_id: String, - }, + RoomCleared { room_id: String }, #[serde(rename = "ai_tool_usage")] AiToolUsage { room_id: String, @@ -194,12 +202,10 @@ pub enum WsServerMessage { delta: String, }, #[serde(rename = "ai_stream_end")] - AiStreamEnd { - room_id: String, - message_id: String, - }, + AiStreamEnd { room_id: String, message_id: String }, } +/// Message shape sent to clients for history loading and live updates. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct MessagePayload { pub id: String, @@ -224,7 +230,7 @@ pub struct MessagePayload { /// Compute Gravatar-compatible MD5 hash from an email address. pub fn gravatar_hash(email: &str) -> String { - use md5::{Md5, Digest}; + use md5::{Digest, Md5}; let normalized = email.trim().to_lowercase(); let result = Md5::digest(normalized.as_bytes()); format!("{:x}", result) @@ -232,13 +238,14 @@ pub fn gravatar_hash(email: &str) -> String { /// Compute SHA-256 integrity hash from created_at timestamp + message content. pub fn message_hash(created_at: &str, content: &str) -> String { - use sha2::{Sha256, Digest}; + use sha2::{Digest, Sha256}; let mut hasher = Sha256::new(); hasher.update(created_at.as_bytes()); hasher.update(content.as_bytes()); format!("{:x}", hasher.finalize()) } +/// Usage and tool metadata captured for AI-generated messages. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct AiMeta { pub model: String, @@ -250,6 +257,7 @@ pub struct AiMeta { pub tool_results: Option>, } +/// One tool invocation performed while generating an AI answer. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ToolResult { pub tool: String, @@ -259,6 +267,7 @@ pub struct ToolResult { // ── Broadcast event (internal channel) ── +/// Internal event sent through a Tokio broadcast channel to WebSocket tasks. #[derive(Debug, Clone)] pub struct BroadcastEvent { pub room_id: String, @@ -267,9 +276,10 @@ pub struct BroadcastEvent { // ── JWT Claims ── +/// Claims stored inside the server-issued JWT. #[derive(Debug, Serialize, Deserialize)] pub struct Claims { - pub sub: String, // user_id + pub sub: String, // user_id pub email: String, pub display_name: String, pub exp: usize, @@ -277,6 +287,7 @@ pub struct Claims { // ── Pagination ── +/// Common pagination parameters for message history endpoints. #[derive(Debug, Deserialize)] pub struct PaginationParams { #[serde(default = "default_limit")] @@ -288,7 +299,7 @@ fn default_limit() -> i64 { 50 } -/// Returns "" if the email is a sentinel nostr: value, otherwise returns it as-is. +/// Hide placeholder `nostr:*` emails from normal client responses. pub fn public_email(email: &str) -> String { if email.starts_with("nostr:") { String::new() @@ -299,11 +310,13 @@ pub fn public_email(email: &str) -> String { // ── Nostr auth types ── +/// Response returned by the Nostr challenge endpoint. #[derive(Debug, Serialize)] pub struct NostrChallengeResponse { pub challenge: String, } +/// JSON body sent by the client when proving Nostr ownership. #[derive(Debug, Deserialize)] pub struct NostrVerifyRequest { pub signed_event: String, @@ -312,6 +325,7 @@ pub struct NostrVerifyRequest { pub profile_picture: Option, } +/// JSON body for inviting an already-known Nostr user into a room. #[derive(Debug, Deserialize)] pub struct NostrInviteRequest { pub room_id: String, diff --git a/server/src/services/brave.rs b/server/src/services/brave.rs index a949cb4..c2eceea 100644 --- a/server/src/services/brave.rs +++ b/server/src/services/brave.rs @@ -4,6 +4,7 @@ use crate::services::search::SearchResult; const BRAVE_SEARCH_URL: &str = "https://api.search.brave.com/res/v1/web/search"; +/// Partial Brave API response containing only the fields this app needs. #[derive(Debug, Deserialize)] struct BraveResponse { web: Option, @@ -27,11 +28,7 @@ struct BraveResult { /// Search the web using the Brave Search API. /// Returns a list of simplified search results. -pub async fn search( - query: &str, - api_key: &str, - count: u8, -) -> Result, String> { +pub async fn search(query: &str, api_key: &str, count: u8) -> Result, String> { let count = count.clamp(1, 10); let client = reqwest::Client::new(); diff --git a/server/src/services/fetch.rs b/server/src/services/fetch.rs index 0227156..81ff54e 100644 --- a/server/src/services/fetch.rs +++ b/server/src/services/fetch.rs @@ -19,9 +19,29 @@ const STRIP_TAGS: &[&str] = &[ /// Block-level tags that should produce newlines in text output. const BLOCK_TAGS: &[&str] = &[ - "p", "div", "h1", "h2", "h3", "h4", "h5", "h6", "li", "br", "tr", - "blockquote", "pre", "section", "article", "main", "header", - "dt", "dd", "figcaption", "table", "thead", "tbody", + "p", + "div", + "h1", + "h2", + "h3", + "h4", + "h5", + "h6", + "li", + "br", + "tr", + "blockquote", + "pre", + "section", + "article", + "main", + "header", + "dt", + "dd", + "figcaption", + "table", + "thead", + "tbody", ]; /// Fetch a URL and extract its text content. diff --git a/server/src/services/mod.rs b/server/src/services/mod.rs index f382938..07cb773 100644 --- a/server/src/services/mod.rs +++ b/server/src/services/mod.rs @@ -1,3 +1,9 @@ +//! Integrations with external systems used by the chat server. +//! +//! These modules wrap search providers, web page fetching, and the OpenRouter +//! chat completion API so the rest of the application can call them with simple +//! Rust types. + pub mod brave; pub mod fetch; pub mod openrouter; diff --git a/server/src/services/openrouter.rs b/server/src/services/openrouter.rs index c78cc98..b10e7fd 100644 --- a/server/src/services/openrouter.rs +++ b/server/src/services/openrouter.rs @@ -235,7 +235,9 @@ pub async fn chat_completion_stream( { Ok(r) => r, Err(e) => { - let _ = tx.send(StreamEvent::Error(format!("Request failed: {}", e))).await; + let _ = tx + .send(StreamEvent::Error(format!("Request failed: {}", e))) + .await; return; } }; @@ -243,7 +245,12 @@ pub async fn chat_completion_stream( if !response.status().is_success() { let status = response.status(); let body = response.text().await.unwrap_or_default(); - let _ = tx.send(StreamEvent::Error(format!("OpenRouter error {}: {}", status, body))).await; + let _ = tx + .send(StreamEvent::Error(format!( + "OpenRouter error {}: {}", + status, body + ))) + .await; return; } @@ -264,7 +271,9 @@ pub async fn chat_completion_stream( let bytes = match chunk_result { Ok(b) => b, Err(e) => { - let _ = tx.send(StreamEvent::Error(format!("Stream error: {}", e))).await; + let _ = tx + .send(StreamEvent::Error(format!("Stream error: {}", e))) + .await; return; } }; @@ -338,7 +347,10 @@ pub async fn chat_completion_stream( tool_call_accum[idx].function.name.push_str(name); } if let Some(args) = &func.arguments { - tool_call_accum[idx].function.arguments.push_str(args); + tool_call_accum[idx] + .function + .arguments + .push_str(args); } } } @@ -373,7 +385,11 @@ pub async fn chat_completion_stream( // AI requested tool calls let assistant_msg = ChatMessage { role: "assistant".into(), - content: if full_content.is_empty() { None } else { Some(Content::Text(full_content)) }, + content: if full_content.is_empty() { + None + } else { + Some(Content::Text(full_content)) + }, tool_calls: Some(tool_call_accum), tool_call_id: None, }; @@ -420,7 +436,9 @@ pub fn build_chat_history( Content::Parts(vec![ ContentPart::Text { text }, ContentPart::ImageUrl { - image_url: ImageUrlData { url: data_url.clone() }, + image_url: ImageUrlData { + url: data_url.clone(), + }, }, ]) } else { diff --git a/server/src/services/search.rs b/server/src/services/search.rs index fc351e1..ed62969 100644 --- a/server/src/services/search.rs +++ b/server/src/services/search.rs @@ -2,6 +2,7 @@ use serde::{Deserialize, Serialize}; use super::{brave, tavily}; +/// Which search backend the AI tool layer should call. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum SearchProvider { Tavily, @@ -9,8 +10,14 @@ pub enum SearchProvider { } impl SearchProvider { + /// Parse the `SEARCH_PROVIDER` environment variable into a supported variant. pub fn from_env(value: Option<&str>) -> Result { - match value.unwrap_or("tavily").trim().to_ascii_lowercase().as_str() { + match value + .unwrap_or("tavily") + .trim() + .to_ascii_lowercase() + .as_str() + { "tavily" => Ok(Self::Tavily), "brave" => Ok(Self::Brave), other => Err(format!( @@ -20,6 +27,7 @@ impl SearchProvider { } } + /// Return the environment variable name required by the selected provider. pub fn required_key_name(self) -> &'static str { match self { Self::Tavily => "TAVILY_API_KEY", @@ -28,6 +36,7 @@ impl SearchProvider { } } +/// Normalized search result shape shared across providers. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct SearchResult { pub title: String, @@ -36,6 +45,7 @@ pub struct SearchResult { pub age: Option, } +/// Dispatch a search request to whichever provider the server is configured to use. pub async fn search( provider: SearchProvider, query: &str, @@ -59,6 +69,7 @@ pub async fn search( } } +/// Turn search results into plain text the AI model can read as tool output. pub fn format_results(results: &[SearchResult]) -> String { if results.is_empty() { return "No search results found.".to_string(); diff --git a/server/src/services/tavily.rs b/server/src/services/tavily.rs index e516bfd..ff2102f 100644 --- a/server/src/services/tavily.rs +++ b/server/src/services/tavily.rs @@ -23,11 +23,7 @@ struct TavilyResult { published_date: Option, } -pub async fn search( - query: &str, - api_key: &str, - count: u8, -) -> Result, String> { +pub async fn search(query: &str, api_key: &str, count: u8) -> Result, String> { let max_results = count.clamp(1, 10); let client = reqwest::Client::new(); @@ -75,7 +71,10 @@ pub async fn search( if first.description.is_empty() { first.description = format!("AI summary: {}", answer); } else { - first.description = format!("AI summary: {}\nSource excerpt: {}", answer, first.description); + first.description = format!( + "AI summary: {}\nSource excerpt: {}", + answer, first.description + ); } } else { results.push(SearchResult {