Replaces batch AI responses with real-time SSE streaming from OpenRouter. Tokens are buffered client-side and drained via requestAnimationFrame for a smooth typing effect instead of choppy chunk dumps.

Backend:
- Rewrite openrouter service for SSE streaming with incremental tool call accumulation
- Add AiStreamChunk/AiStreamEnd WebSocket event types
- Stream content deltas to clients during all tool call rounds
- Increase broadcast channel capacity (256 -> 4096) and handle Lagged errors gracefully

Frontend:
- Add StreamBuffer utility with adaptive rAF-based character draining
- Show streaming message-bubble with blinking cursor during generation
- Clean up buffer on room switch and final message replacement

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
464 lines
16 KiB
Rust
464 lines
16 KiB
Rust
use axum::{
|
|
extract::{
|
|
ws::{Message, WebSocket},
|
|
Query, State, WebSocketUpgrade,
|
|
},
|
|
response::IntoResponse,
|
|
};
|
|
use futures::{SinkExt, StreamExt};
|
|
use std::sync::Arc;
|
|
use uuid::Uuid;
|
|
|
|
use crate::{
|
|
middleware::auth::decode_token,
|
|
models::{BroadcastEvent, MessagePayload, WsClientMessage, WsServerMessage},
|
|
services::{brave, fetch, openrouter},
|
|
AppState,
|
|
};
|
|
|
|
/// Upper bound on the number of completion rounds the AI tool loop may run for a
/// single reply; each round in which the model requests tools consumes one.
const MAX_TOOL_ROUNDS: usize = 5;
|
|
|
|
#[derive(serde::Deserialize)]
|
|
pub struct WsQuery {
|
|
token: String,
|
|
}
|
|
|
|
pub async fn ws_handler(
|
|
ws: WebSocketUpgrade,
|
|
State(state): State<Arc<AppState>>,
|
|
Query(query): Query<WsQuery>,
|
|
) -> impl IntoResponse {
|
|
// Authenticate before upgrading
|
|
let claims = match decode_token(&query.token, &state.jwt_secret) {
|
|
Ok(c) => c,
|
|
Err(_) => {
|
|
return axum::http::StatusCode::UNAUTHORIZED.into_response();
|
|
}
|
|
};
|
|
|
|
ws.on_upgrade(move |socket| handle_socket(socket, state, claims.sub, claims.display_name))
|
|
}
|
|
|
|
async fn handle_socket(socket: WebSocket, state: Arc<AppState>, user_id: String, display_name: String) {
|
|
let (mut ws_tx, mut ws_rx) = socket.split();
|
|
let mut broadcast_rx = state.tx.subscribe();
|
|
|
|
// Track which rooms this user is watching
|
|
let subscribed_rooms: Arc<tokio::sync::Mutex<std::collections::HashSet<String>>> =
|
|
Arc::new(tokio::sync::Mutex::new(std::collections::HashSet::new()));
|
|
|
|
let rooms_clone = subscribed_rooms.clone();
|
|
|
|
// Task: forward broadcast events to this client
|
|
let mut send_task = tokio::spawn(async move {
|
|
loop {
|
|
match broadcast_rx.recv().await {
|
|
Ok(event) => {
|
|
let rooms = rooms_clone.lock().await;
|
|
if rooms.contains(&event.room_id) {
|
|
let msg = serde_json::to_string(&event.message).unwrap();
|
|
if ws_tx.send(Message::Text(msg.into())).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
|
|
tracing::warn!("WS subscriber lagged, skipped {} messages", n);
|
|
// Continue receiving — don't drop the connection
|
|
continue;
|
|
}
|
|
Err(tokio::sync::broadcast::error::RecvError::Closed) => {
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
});
|
|
|
|
let state_clone = state.clone();
|
|
let user_id_clone = user_id.clone();
|
|
let display_name_clone = display_name.clone();
|
|
let rooms_clone2 = subscribed_rooms.clone();
|
|
|
|
// Task: receive messages from client
|
|
let mut recv_task = tokio::spawn(async move {
|
|
while let Some(Ok(msg)) = ws_rx.next().await {
|
|
let text = match msg {
|
|
Message::Text(t) => t.to_string(),
|
|
Message::Close(_) => break,
|
|
_ => continue,
|
|
};
|
|
|
|
let client_msg: WsClientMessage = match serde_json::from_str(&text) {
|
|
Ok(m) => m,
|
|
Err(e) => {
|
|
let _ = state_clone.tx.send(BroadcastEvent {
|
|
room_id: String::new(),
|
|
message: WsServerMessage::Error {
|
|
message: format!("Invalid message: {}", e),
|
|
},
|
|
});
|
|
continue;
|
|
}
|
|
};
|
|
|
|
match client_msg {
|
|
WsClientMessage::JoinRoom { room_id } => {
|
|
tracing::info!("User {} joined room {}", user_id_clone, room_id);
|
|
rooms_clone2.lock().await.insert(room_id.clone());
|
|
}
|
|
WsClientMessage::Typing { room_id } => {
|
|
let _ = state_clone.tx.send(BroadcastEvent {
|
|
room_id: room_id.clone(),
|
|
message: WsServerMessage::UserTyping {
|
|
room_id,
|
|
user_id: user_id_clone.clone(),
|
|
display_name: display_name_clone.clone(),
|
|
},
|
|
});
|
|
}
|
|
WsClientMessage::SendMessage {
|
|
room_id,
|
|
content,
|
|
mentions,
|
|
} => {
|
|
tracing::info!("User {} sending message to room {}", user_id_clone, room_id);
|
|
handle_send_message(
|
|
&state_clone,
|
|
&user_id_clone,
|
|
&display_name_clone,
|
|
&room_id,
|
|
&content,
|
|
&mentions,
|
|
)
|
|
.await;
|
|
}
|
|
}
|
|
}
|
|
});
|
|
|
|
// Wait for either task to finish, then abort the other
|
|
tokio::select! {
|
|
_ = &mut send_task => recv_task.abort(),
|
|
_ = &mut recv_task => send_task.abort(),
|
|
}
|
|
|
|
tracing::info!("WebSocket disconnected: {}", user_id);
|
|
}
|
|
|
|
async fn handle_send_message(
|
|
state: &Arc<AppState>,
|
|
user_id: &str,
|
|
display_name: &str,
|
|
room_id: &str,
|
|
content: &str,
|
|
mentions: &[String],
|
|
) {
|
|
let msg_id = Uuid::new_v4().to_string();
|
|
let mentions_json = serde_json::to_string(mentions).unwrap_or_else(|_| "[]".to_string());
|
|
let now = chrono::Utc::now().to_rfc3339();
|
|
|
|
// Store in database
|
|
let _ = sqlx::query(
|
|
"INSERT INTO messages (id, room_id, sender_id, sender_name, content, mentions, is_ai, created_at) VALUES (?, ?, ?, ?, ?, ?, 0, ?)",
|
|
)
|
|
.bind(&msg_id)
|
|
.bind(room_id)
|
|
.bind(user_id)
|
|
.bind(display_name)
|
|
.bind(content)
|
|
.bind(&mentions_json)
|
|
.bind(&now)
|
|
.execute(&state.db)
|
|
.await;
|
|
|
|
// Broadcast human message
|
|
let payload = MessagePayload {
|
|
id: msg_id,
|
|
room_id: room_id.to_string(),
|
|
sender_id: user_id.to_string(),
|
|
sender_name: display_name.to_string(),
|
|
content: content.to_string(),
|
|
mentions: mentions.to_vec(),
|
|
is_ai: false,
|
|
created_at: now,
|
|
ai_meta: None,
|
|
};
|
|
|
|
let _ = state.tx.send(BroadcastEvent {
|
|
room_id: room_id.to_string(),
|
|
message: WsServerMessage::NewMessage {
|
|
message: payload,
|
|
},
|
|
});
|
|
|
|
// Check if AI should respond
|
|
let ai_user_id = "ai-assistant";
|
|
let should_respond = mentions.contains(&ai_user_id.to_string());
|
|
|
|
// Also check room settings for ai_always_respond
|
|
let room = sqlx::query_as::<_, (String, bool, String)>(
|
|
"SELECT model_id, ai_always_respond, system_prompt FROM rooms WHERE id = ? AND deleted_at IS NULL",
|
|
)
|
|
.bind(room_id)
|
|
.fetch_optional(&state.db)
|
|
.await;
|
|
|
|
let (model_id, always_respond, system_prompt) = match room {
|
|
Ok(Some(r)) => r,
|
|
_ => return,
|
|
};
|
|
|
|
if !should_respond && !always_respond {
|
|
return;
|
|
}
|
|
|
|
// Signal AI is typing
|
|
let _ = state.tx.send(BroadcastEvent {
|
|
room_id: room_id.to_string(),
|
|
message: WsServerMessage::AiTyping {
|
|
room_id: room_id.to_string(),
|
|
},
|
|
});
|
|
|
|
// Fetch recent history
|
|
let recent_messages = sqlx::query_as::<_, (String, String, bool)>(
|
|
"SELECT sender_name, content, is_ai FROM messages WHERE room_id = ? ORDER BY created_at DESC LIMIT 50",
|
|
)
|
|
.bind(room_id)
|
|
.fetch_all(&state.db)
|
|
.await
|
|
.unwrap_or_default();
|
|
|
|
let history: Vec<(String, String, bool)> = recent_messages.into_iter().rev().collect();
|
|
let mut chat_history = openrouter::build_chat_history(&system_prompt, &history);
|
|
|
|
// Build tools for AI
|
|
let tools = openrouter::build_tools();
|
|
|
|
// Pre-generate AI message ID so we can reference it in stream chunks
|
|
let ai_msg_id = Uuid::new_v4().to_string();
|
|
|
|
// Call OpenRouter with tool loop — uses streaming for all rounds
|
|
let mut total_prompt_tokens: u32 = 0;
|
|
let mut total_completion_tokens: u32 = 0;
|
|
let mut total_response_ms: u64 = 0;
|
|
let mut final_model = model_id.clone();
|
|
let mut ai_response = String::new();
|
|
let mut had_error = false;
|
|
let mut collected_tool_results: Vec<crate::models::ToolResult> = vec![];
|
|
|
|
'tool_loop: for round in 0..MAX_TOOL_ROUNDS {
|
|
let mut stream_rx = openrouter::chat_completion_stream(
|
|
chat_history.clone(),
|
|
&model_id,
|
|
&state.openrouter_key,
|
|
Some(tools.clone()),
|
|
)
|
|
.await;
|
|
|
|
while let Some(event) = stream_rx.recv().await {
|
|
match event {
|
|
openrouter::StreamEvent::Delta(text) => {
|
|
// Broadcast each content chunk to clients
|
|
let _ = state.tx.send(BroadcastEvent {
|
|
room_id: room_id.to_string(),
|
|
message: WsServerMessage::AiStreamChunk {
|
|
room_id: room_id.to_string(),
|
|
message_id: ai_msg_id.clone(),
|
|
delta: text.clone(),
|
|
},
|
|
});
|
|
ai_response.push_str(&text);
|
|
}
|
|
openrouter::StreamEvent::ToolCalls(assistant_msg, stats) => {
|
|
total_prompt_tokens += stats.prompt_tokens;
|
|
total_completion_tokens += stats.completion_tokens;
|
|
total_response_ms += stats.response_ms;
|
|
final_model = stats.model.clone();
|
|
|
|
tracing::info!(
|
|
"AI requesting tool calls (round {}): {:?}",
|
|
round + 1,
|
|
assistant_msg.tool_calls.as_ref().map(|tc| tc.iter().map(|t| &t.function.name).collect::<Vec<_>>())
|
|
);
|
|
|
|
// Add the assistant's tool-call message to history
|
|
let tool_calls = assistant_msg.tool_calls.clone().unwrap_or_default();
|
|
chat_history.push(assistant_msg);
|
|
|
|
// Execute each tool call and add results
|
|
for tool_call in &tool_calls {
|
|
let tool_input = extract_tool_input(&tool_call.function.name, &tool_call.function.arguments);
|
|
|
|
// Broadcast real-time tool usage event
|
|
let _ = state.tx.send(BroadcastEvent {
|
|
room_id: room_id.to_string(),
|
|
message: WsServerMessage::AiToolUsage {
|
|
room_id: room_id.to_string(),
|
|
tool_name: tool_call.function.name.clone(),
|
|
status: "calling".to_string(),
|
|
},
|
|
});
|
|
|
|
let tool_result = execute_tool(
|
|
&tool_call.function.name,
|
|
&tool_call.function.arguments,
|
|
&state.brave_api_key,
|
|
)
|
|
.await;
|
|
|
|
tracing::info!(
|
|
"Tool {} result: {} chars",
|
|
tool_call.function.name,
|
|
tool_result.len()
|
|
);
|
|
|
|
collected_tool_results.push(crate::models::ToolResult {
|
|
tool: tool_call.function.name.clone(),
|
|
input: tool_input,
|
|
result: tool_result.clone(),
|
|
});
|
|
|
|
chat_history.push(openrouter::ChatMessage {
|
|
role: "tool".into(),
|
|
content: Some(tool_result),
|
|
tool_calls: None,
|
|
tool_call_id: Some(tool_call.id.clone()),
|
|
});
|
|
}
|
|
// Continue to next round (tool loop)
|
|
continue 'tool_loop;
|
|
}
|
|
openrouter::StreamEvent::Done(stats) => {
|
|
total_prompt_tokens += stats.prompt_tokens;
|
|
total_completion_tokens += stats.completion_tokens;
|
|
total_response_ms += stats.response_ms;
|
|
final_model = stats.model;
|
|
break 'tool_loop;
|
|
}
|
|
openrouter::StreamEvent::Error(e) => {
|
|
tracing::error!("OpenRouter stream error (round {}): {}", round + 1, e);
|
|
ai_response = format!("*Sorry, I encountered an error: {}*", e);
|
|
had_error = true;
|
|
break 'tool_loop;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// If we exhausted all rounds without a text response, note it
|
|
if ai_response.is_empty() && !had_error {
|
|
ai_response = "*I used several tools but couldn't formulate a final response. Please try again.*".to_string();
|
|
}
|
|
|
|
// Signal stream end so client can finalize rendering
|
|
let _ = state.tx.send(BroadcastEvent {
|
|
room_id: room_id.to_string(),
|
|
message: WsServerMessage::AiStreamEnd {
|
|
room_id: room_id.to_string(),
|
|
message_id: ai_msg_id.clone(),
|
|
},
|
|
});
|
|
|
|
let ai_meta = if !had_error {
|
|
Some(crate::models::AiMeta {
|
|
model: final_model,
|
|
prompt_tokens: total_prompt_tokens,
|
|
completion_tokens: total_completion_tokens,
|
|
total_tokens: total_prompt_tokens + total_completion_tokens,
|
|
response_ms: total_response_ms,
|
|
tool_results: if collected_tool_results.is_empty() {
|
|
None
|
|
} else {
|
|
Some(collected_tool_results)
|
|
},
|
|
})
|
|
} else {
|
|
None
|
|
};
|
|
|
|
// Store AI response
|
|
let ai_now = chrono::Utc::now().to_rfc3339();
|
|
|
|
// Serialize ai_meta for database storage
|
|
let ai_meta_json = ai_meta.as_ref().and_then(|m| serde_json::to_string(m).ok());
|
|
|
|
let _ = sqlx::query(
|
|
"INSERT INTO messages (id, room_id, sender_id, sender_name, content, mentions, is_ai, created_at, ai_meta) VALUES (?, ?, ?, ?, ?, '[]', 1, ?, ?)",
|
|
)
|
|
.bind(&ai_msg_id)
|
|
.bind(room_id)
|
|
.bind(ai_user_id)
|
|
.bind("AI Assistant")
|
|
.bind(&ai_response)
|
|
.bind(&ai_now)
|
|
.bind(&ai_meta_json)
|
|
.execute(&state.db)
|
|
.await;
|
|
|
|
// Broadcast final AI message (includes full content + ai_meta)
|
|
let ai_payload = MessagePayload {
|
|
id: ai_msg_id,
|
|
room_id: room_id.to_string(),
|
|
sender_id: ai_user_id.to_string(),
|
|
sender_name: "AI Assistant".to_string(),
|
|
content: ai_response,
|
|
mentions: vec![],
|
|
is_ai: true,
|
|
created_at: ai_now,
|
|
ai_meta,
|
|
};
|
|
|
|
let _ = state.tx.send(BroadcastEvent {
|
|
room_id: room_id.to_string(),
|
|
message: WsServerMessage::NewMessage {
|
|
message: ai_payload,
|
|
},
|
|
});
|
|
}
|
|
|
|
/// Extract a human-readable input string from tool arguments (for UI display).
|
|
fn extract_tool_input(tool_name: &str, arguments: &str) -> String {
|
|
let args: serde_json::Value = serde_json::from_str(arguments).unwrap_or_default();
|
|
match tool_name {
|
|
"brave_search" => args["query"].as_str().unwrap_or("").to_string(),
|
|
"web_fetch" => args["url"].as_str().unwrap_or("").to_string(),
|
|
_ => arguments.to_string(),
|
|
}
|
|
}
|
|
|
|
/// Execute a tool call by name, returning the result as a string.
|
|
async fn execute_tool(name: &str, arguments: &str, brave_api_key: &str) -> String {
|
|
match name {
|
|
"brave_search" => {
|
|
let args: serde_json::Value = serde_json::from_str(arguments).unwrap_or_default();
|
|
let query = args["query"].as_str().unwrap_or("").to_string();
|
|
let count = args["count"].as_u64().unwrap_or(5) as u8;
|
|
|
|
if query.is_empty() {
|
|
return "Error: search query is required".into();
|
|
}
|
|
|
|
match brave::search(&query, brave_api_key, count).await {
|
|
Ok(results) => brave::format_results(&results),
|
|
Err(e) => format!("Search error: {}", e),
|
|
}
|
|
}
|
|
"web_fetch" => {
|
|
let args: serde_json::Value = serde_json::from_str(arguments).unwrap_or_default();
|
|
let url = args["url"].as_str().unwrap_or("").to_string();
|
|
|
|
if url.is_empty() {
|
|
return "Error: URL is required".into();
|
|
}
|
|
|
|
match fetch::fetch_url(&url, 8000).await {
|
|
Ok(result) => fetch::format_result(&result),
|
|
Err(e) => format!("Fetch error: {}", e),
|
|
}
|
|
}
|
|
_ => format!("Unknown tool: {}", name),
|
|
}
|
|
}
|