use axum::{ extract::{ ws::{Message, WebSocket}, Query, State, WebSocketUpgrade, }, response::IntoResponse, }; use futures::{SinkExt, StreamExt}; use std::sync::Arc; use uuid::Uuid; use crate::{ middleware::auth::decode_token, models::{BroadcastEvent, MessagePayload, WsClientMessage, WsServerMessage}, services::{brave, fetch, openrouter}, AppState, }; /// Maximum number of tool call rounds before forcing a text response. const MAX_TOOL_ROUNDS: usize = 5; #[derive(serde::Deserialize)] pub struct WsQuery { token: String, } pub async fn ws_handler( ws: WebSocketUpgrade, State(state): State>, Query(query): Query, ) -> impl IntoResponse { // Authenticate before upgrading let claims = match decode_token(&query.token, &state.jwt_secret) { Ok(c) => c, Err(_) => { return axum::http::StatusCode::UNAUTHORIZED.into_response(); } }; ws.on_upgrade(move |socket| handle_socket(socket, state, claims.sub, claims.display_name)) } async fn handle_socket(socket: WebSocket, state: Arc, user_id: String, display_name: String) { let (mut ws_tx, mut ws_rx) = socket.split(); let mut broadcast_rx = state.tx.subscribe(); // Track which rooms this user is watching let subscribed_rooms: Arc>> = Arc::new(tokio::sync::Mutex::new(std::collections::HashSet::new())); let rooms_clone = subscribed_rooms.clone(); // Task: forward broadcast events to this client let mut send_task = tokio::spawn(async move { loop { match broadcast_rx.recv().await { Ok(event) => { let rooms = rooms_clone.lock().await; if rooms.contains(&event.room_id) { let msg = serde_json::to_string(&event.message).unwrap(); if ws_tx.send(Message::Text(msg.into())).await.is_err() { break; } } } Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { tracing::warn!("WS subscriber lagged, skipped {} messages", n); // Continue receiving — don't drop the connection continue; } Err(tokio::sync::broadcast::error::RecvError::Closed) => { break; } } } }); let state_clone = state.clone(); let user_id_clone = user_id.clone(); let display_name_clone = display_name.clone(); let rooms_clone2 = subscribed_rooms.clone(); // Task: receive messages from client let mut recv_task = tokio::spawn(async move { while let Some(Ok(msg)) = ws_rx.next().await { let text = match msg { Message::Text(t) => t.to_string(), Message::Close(_) => break, _ => continue, }; let client_msg: WsClientMessage = match serde_json::from_str(&text) { Ok(m) => m, Err(e) => { let _ = state_clone.tx.send(BroadcastEvent { room_id: String::new(), message: WsServerMessage::Error { message: format!("Invalid message: {}", e), }, }); continue; } }; match client_msg { WsClientMessage::JoinRoom { room_id } => { tracing::info!("User {} joined room {}", user_id_clone, room_id); rooms_clone2.lock().await.insert(room_id.clone()); } WsClientMessage::Typing { room_id } => { let _ = state_clone.tx.send(BroadcastEvent { room_id: room_id.clone(), message: WsServerMessage::UserTyping { room_id, user_id: user_id_clone.clone(), display_name: display_name_clone.clone(), }, }); } WsClientMessage::SendMessage { room_id, content, mentions, } => { tracing::info!("User {} sending message to room {}", user_id_clone, room_id); handle_send_message( &state_clone, &user_id_clone, &display_name_clone, &room_id, &content, &mentions, ) .await; } } } }); // Wait for either task to finish, then abort the other tokio::select! { _ = &mut send_task => recv_task.abort(), _ = &mut recv_task => send_task.abort(), } tracing::info!("WebSocket disconnected: {}", user_id); } async fn handle_send_message( state: &Arc, user_id: &str, display_name: &str, room_id: &str, content: &str, mentions: &[String], ) { let msg_id = Uuid::new_v4().to_string(); let mentions_json = serde_json::to_string(mentions).unwrap_or_else(|_| "[]".to_string()); let now = chrono::Utc::now().to_rfc3339(); // Store in database let _ = sqlx::query( "INSERT INTO messages (id, room_id, sender_id, sender_name, content, mentions, is_ai, created_at) VALUES (?, ?, ?, ?, ?, ?, 0, ?)", ) .bind(&msg_id) .bind(room_id) .bind(user_id) .bind(display_name) .bind(content) .bind(&mentions_json) .bind(&now) .execute(&state.db) .await; // Broadcast human message let payload = MessagePayload { id: msg_id, room_id: room_id.to_string(), sender_id: user_id.to_string(), sender_name: display_name.to_string(), content: content.to_string(), mentions: mentions.to_vec(), is_ai: false, created_at: now, ai_meta: None, }; let _ = state.tx.send(BroadcastEvent { room_id: room_id.to_string(), message: WsServerMessage::NewMessage { message: payload, }, }); // Check if AI should respond let ai_user_id = "ai-assistant"; let should_respond = mentions.contains(&ai_user_id.to_string()); // Also check room settings for ai_always_respond let room = sqlx::query_as::<_, (String, bool, String)>( "SELECT model_id, ai_always_respond, system_prompt FROM rooms WHERE id = ? AND deleted_at IS NULL", ) .bind(room_id) .fetch_optional(&state.db) .await; let (model_id, always_respond, system_prompt) = match room { Ok(Some(r)) => r, _ => return, }; if !should_respond && !always_respond { return; } // Signal AI is typing let _ = state.tx.send(BroadcastEvent { room_id: room_id.to_string(), message: WsServerMessage::AiTyping { room_id: room_id.to_string(), }, }); // Fetch recent history let recent_messages = sqlx::query_as::<_, (String, String, bool)>( "SELECT sender_name, content, is_ai FROM messages WHERE room_id = ? ORDER BY created_at DESC LIMIT 50", ) .bind(room_id) .fetch_all(&state.db) .await .unwrap_or_default(); let history: Vec<(String, String, bool)> = recent_messages.into_iter().rev().collect(); let mut chat_history = openrouter::build_chat_history(&system_prompt, &history); // Build tools for AI let tools = openrouter::build_tools(); // Pre-generate AI message ID so we can reference it in stream chunks let ai_msg_id = Uuid::new_v4().to_string(); // Call OpenRouter with tool loop — uses streaming for all rounds let mut total_prompt_tokens: u32 = 0; let mut total_completion_tokens: u32 = 0; let mut total_response_ms: u64 = 0; let mut final_model = model_id.clone(); let mut ai_response = String::new(); let mut had_error = false; let mut collected_tool_results: Vec = vec![]; 'tool_loop: for round in 0..MAX_TOOL_ROUNDS { let mut stream_rx = openrouter::chat_completion_stream( chat_history.clone(), &model_id, &state.openrouter_key, Some(tools.clone()), ) .await; while let Some(event) = stream_rx.recv().await { match event { openrouter::StreamEvent::Delta(text) => { // Broadcast each content chunk to clients let _ = state.tx.send(BroadcastEvent { room_id: room_id.to_string(), message: WsServerMessage::AiStreamChunk { room_id: room_id.to_string(), message_id: ai_msg_id.clone(), delta: text.clone(), }, }); ai_response.push_str(&text); } openrouter::StreamEvent::ToolCalls(assistant_msg, stats) => { total_prompt_tokens += stats.prompt_tokens; total_completion_tokens += stats.completion_tokens; total_response_ms += stats.response_ms; final_model = stats.model.clone(); tracing::info!( "AI requesting tool calls (round {}): {:?}", round + 1, assistant_msg.tool_calls.as_ref().map(|tc| tc.iter().map(|t| &t.function.name).collect::>()) ); // Add the assistant's tool-call message to history let tool_calls = assistant_msg.tool_calls.clone().unwrap_or_default(); chat_history.push(assistant_msg); // Execute each tool call and add results for tool_call in &tool_calls { let tool_input = extract_tool_input(&tool_call.function.name, &tool_call.function.arguments); // Broadcast real-time tool usage event let _ = state.tx.send(BroadcastEvent { room_id: room_id.to_string(), message: WsServerMessage::AiToolUsage { room_id: room_id.to_string(), tool_name: tool_call.function.name.clone(), status: "calling".to_string(), }, }); let tool_result = execute_tool( &tool_call.function.name, &tool_call.function.arguments, &state.brave_api_key, ) .await; tracing::info!( "Tool {} result: {} chars", tool_call.function.name, tool_result.len() ); collected_tool_results.push(crate::models::ToolResult { tool: tool_call.function.name.clone(), input: tool_input, result: tool_result.clone(), }); chat_history.push(openrouter::ChatMessage { role: "tool".into(), content: Some(tool_result), tool_calls: None, tool_call_id: Some(tool_call.id.clone()), }); } // Continue to next round (tool loop) continue 'tool_loop; } openrouter::StreamEvent::Done(stats) => { total_prompt_tokens += stats.prompt_tokens; total_completion_tokens += stats.completion_tokens; total_response_ms += stats.response_ms; final_model = stats.model; break 'tool_loop; } openrouter::StreamEvent::Error(e) => { tracing::error!("OpenRouter stream error (round {}): {}", round + 1, e); ai_response = format!("*Sorry, I encountered an error: {}*", e); had_error = true; break 'tool_loop; } } } } // If we exhausted all rounds without a text response, note it if ai_response.is_empty() && !had_error { ai_response = "*I used several tools but couldn't formulate a final response. Please try again.*".to_string(); } // Signal stream end so client can finalize rendering let _ = state.tx.send(BroadcastEvent { room_id: room_id.to_string(), message: WsServerMessage::AiStreamEnd { room_id: room_id.to_string(), message_id: ai_msg_id.clone(), }, }); let ai_meta = if !had_error { Some(crate::models::AiMeta { model: final_model, prompt_tokens: total_prompt_tokens, completion_tokens: total_completion_tokens, total_tokens: total_prompt_tokens + total_completion_tokens, response_ms: total_response_ms, tool_results: if collected_tool_results.is_empty() { None } else { Some(collected_tool_results) }, }) } else { None }; // Store AI response let ai_now = chrono::Utc::now().to_rfc3339(); // Serialize ai_meta for database storage let ai_meta_json = ai_meta.as_ref().and_then(|m| serde_json::to_string(m).ok()); let _ = sqlx::query( "INSERT INTO messages (id, room_id, sender_id, sender_name, content, mentions, is_ai, created_at, ai_meta) VALUES (?, ?, ?, ?, ?, '[]', 1, ?, ?)", ) .bind(&ai_msg_id) .bind(room_id) .bind(ai_user_id) .bind("AI Assistant") .bind(&ai_response) .bind(&ai_now) .bind(&ai_meta_json) .execute(&state.db) .await; // Broadcast final AI message (includes full content + ai_meta) let ai_payload = MessagePayload { id: ai_msg_id, room_id: room_id.to_string(), sender_id: ai_user_id.to_string(), sender_name: "AI Assistant".to_string(), content: ai_response, mentions: vec![], is_ai: true, created_at: ai_now, ai_meta, }; let _ = state.tx.send(BroadcastEvent { room_id: room_id.to_string(), message: WsServerMessage::NewMessage { message: ai_payload, }, }); } /// Extract a human-readable input string from tool arguments (for UI display). fn extract_tool_input(tool_name: &str, arguments: &str) -> String { let args: serde_json::Value = serde_json::from_str(arguments).unwrap_or_default(); match tool_name { "brave_search" => args["query"].as_str().unwrap_or("").to_string(), "web_fetch" => args["url"].as_str().unwrap_or("").to_string(), _ => arguments.to_string(), } } /// Execute a tool call by name, returning the result as a string. async fn execute_tool(name: &str, arguments: &str, brave_api_key: &str) -> String { match name { "brave_search" => { let args: serde_json::Value = serde_json::from_str(arguments).unwrap_or_default(); let query = args["query"].as_str().unwrap_or("").to_string(); let count = args["count"].as_u64().unwrap_or(5) as u8; if query.is_empty() { return "Error: search query is required".into(); } match brave::search(&query, brave_api_key, count).await { Ok(results) => brave::format_results(&results), Err(e) => format!("Search error: {}", e), } } "web_fetch" => { let args: serde_json::Value = serde_json::from_str(arguments).unwrap_or_default(); let url = args["url"].as_str().unwrap_or("").to_string(); if url.is_empty() { return "Error: URL is required".into(); } match fetch::fetch_url(&url, 8000).await { Ok(result) => fetch::format_result(&result), Err(e) => format!("Fetch error: {}", e), } } _ => format!("Unknown tool: {}", name), } }