Jason Tudisco 4a002c85d4 feat: add streaming AI responses with smooth token-by-token rendering
Replaces batch AI responses with real-time SSE streaming from OpenRouter.
Tokens are buffered client-side and drained via requestAnimationFrame for
a smooth typing effect instead of choppy chunk dumps.

Backend:
- Rewrite openrouter service for SSE streaming with incremental tool call accumulation
- Add AiStreamChunk/AiStreamEnd WebSocket event types
- Stream content deltas to clients during all tool call rounds
- Increase broadcast channel capacity (256 -> 4096) and handle Lagged errors gracefully

Frontend:
- Add StreamBuffer utility with adaptive rAF-based character draining
- Show streaming message-bubble with blinking cursor during generation
- Clean up buffer on room switch and final message replacement

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-06 20:23:49 -06:00

464 lines
16 KiB
Rust

use axum::{
extract::{
ws::{Message, WebSocket},
Query, State, WebSocketUpgrade,
},
response::IntoResponse,
};
use futures::{SinkExt, StreamExt};
use std::sync::Arc;
use uuid::Uuid;
use crate::{
middleware::auth::decode_token,
models::{BroadcastEvent, MessagePayload, WsClientMessage, WsServerMessage},
services::{brave, fetch, openrouter},
AppState,
};
/// Maximum number of tool call rounds before forcing a text response.
const MAX_TOOL_ROUNDS: usize = 5;

/// Query-string parameters for the WebSocket upgrade endpoint.
#[derive(serde::Deserialize)]
pub struct WsQuery {
    // JWT passed as `?token=...` — presumably because the browser WebSocket
    // API cannot set custom headers on the handshake; confirm against client.
    token: String,
}
/// HTTP handler that authenticates and upgrades a connection to a WebSocket.
///
/// The JWT arrives via the `token` query parameter and is validated *before*
/// the protocol upgrade; a bad token yields `401 Unauthorized` without ever
/// establishing a socket.
pub async fn ws_handler(
    ws: WebSocketUpgrade,
    State(state): State<Arc<AppState>>,
    Query(query): Query<WsQuery>,
) -> impl IntoResponse {
    // Reject unauthenticated clients up front — cheaper than tearing down
    // an already-upgraded socket.
    let Ok(claims) = decode_token(&query.token, &state.jwt_secret) else {
        return axum::http::StatusCode::UNAUTHORIZED.into_response();
    };
    ws.on_upgrade(move |socket| handle_socket(socket, state, claims.sub, claims.display_name))
}
/// Per-connection actor: pumps broadcast events out and client messages in.
///
/// Spawns two tasks:
/// 1. `send_task` — forwards app-wide `BroadcastEvent`s to this socket,
///    filtered to rooms this client has joined.
/// 2. `recv_task` — parses incoming frames and dispatches join / typing /
///    send-message actions.
///
/// When either task finishes (socket closed or broadcast channel closed),
/// the other is aborted so the connection fully tears down.
async fn handle_socket(socket: WebSocket, state: Arc<AppState>, user_id: String, display_name: String) {
    let (mut ws_tx, mut ws_rx) = socket.split();
    let mut broadcast_rx = state.tx.subscribe();
    // Track which rooms this user is watching. Shared between both tasks,
    // hence the async Mutex; it is held only for the membership check and
    // never across any other await.
    let subscribed_rooms: Arc<tokio::sync::Mutex<std::collections::HashSet<String>>> =
        Arc::new(tokio::sync::Mutex::new(std::collections::HashSet::new()));
    let rooms_clone = subscribed_rooms.clone();
    // Task: forward broadcast events to this client
    let mut send_task = tokio::spawn(async move {
        loop {
            match broadcast_rx.recv().await {
                Ok(event) => {
                    let rooms = rooms_clone.lock().await;
                    // Deliver only events for rooms this client joined.
                    // NOTE(review): events published with an empty room_id
                    // (e.g. the parse-error path in recv_task below) are
                    // never delivered to anyone, since no client subscribes
                    // to "" — confirm this is intended.
                    if rooms.contains(&event.room_id) {
                        // WsServerMessage is a plain serde enum; serialization
                        // failing would be a programming bug, hence unwrap.
                        let msg = serde_json::to_string(&event.message).unwrap();
                        if ws_tx.send(Message::Text(msg.into())).await.is_err() {
                            // Socket gone — stop forwarding.
                            break;
                        }
                    }
                }
                Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
                    tracing::warn!("WS subscriber lagged, skipped {} messages", n);
                    // Continue receiving — don't drop the connection
                    continue;
                }
                Err(tokio::sync::broadcast::error::RecvError::Closed) => {
                    break;
                }
            }
        }
    });
    let state_clone = state.clone();
    let user_id_clone = user_id.clone();
    let display_name_clone = display_name.clone();
    let rooms_clone2 = subscribed_rooms.clone();
    // Task: receive messages from client
    let mut recv_task = tokio::spawn(async move {
        // Loop ends on socket error, close frame, or clean stream end.
        while let Some(Ok(msg)) = ws_rx.next().await {
            let text = match msg {
                Message::Text(t) => t.to_string(),
                Message::Close(_) => break,
                // Ignore binary/ping/pong frames.
                _ => continue,
            };
            let client_msg: WsClientMessage = match serde_json::from_str(&text) {
                Ok(m) => m,
                Err(e) => {
                    // Report the parse failure and keep the connection alive.
                    // NOTE(review): empty room_id means the send_task filter
                    // above drops this error — the sender never sees it.
                    let _ = state_clone.tx.send(BroadcastEvent {
                        room_id: String::new(),
                        message: WsServerMessage::Error {
                            message: format!("Invalid message: {}", e),
                        },
                    });
                    continue;
                }
            };
            match client_msg {
                WsClientMessage::JoinRoom { room_id } => {
                    tracing::info!("User {} joined room {}", user_id_clone, room_id);
                    // Subscribes this connection to the room's broadcasts.
                    rooms_clone2.lock().await.insert(room_id.clone());
                }
                WsClientMessage::Typing { room_id } => {
                    // Fan the typing indicator out to everyone in the room
                    // (including, via broadcast, the sender's own socket).
                    let _ = state_clone.tx.send(BroadcastEvent {
                        room_id: room_id.clone(),
                        message: WsServerMessage::UserTyping {
                            room_id,
                            user_id: user_id_clone.clone(),
                            display_name: display_name_clone.clone(),
                        },
                    });
                }
                WsClientMessage::SendMessage {
                    room_id,
                    content,
                    mentions,
                } => {
                    tracing::info!("User {} sending message to room {}", user_id_clone, room_id);
                    // Awaited inline: a long AI generation blocks further
                    // frames from THIS client until it completes.
                    handle_send_message(
                        &state_clone,
                        &user_id_clone,
                        &display_name_clone,
                        &room_id,
                        &content,
                        &mentions,
                    )
                    .await;
                }
            }
        }
    });
    // Wait for either task to finish, then abort the other
    tokio::select! {
        _ = &mut send_task => recv_task.abort(),
        _ = &mut recv_task => send_task.abort(),
    }
    tracing::info!("WebSocket disconnected: {}", user_id);
}
/// Persist and broadcast a human message, then (if the AI is mentioned or the
/// room is set to always-respond) run a streaming AI completion with up to
/// `MAX_TOOL_ROUNDS` rounds of tool calls, streaming deltas to the room and
/// finally persisting + broadcasting the complete AI message.
///
/// Errors from sqlx inserts and broadcast sends are intentionally ignored
/// (`let _ =`) — NOTE(review): a failed INSERT still broadcasts the message,
/// so clients can see messages that were never persisted; confirm acceptable.
async fn handle_send_message(
    state: &Arc<AppState>,
    user_id: &str,
    display_name: &str,
    room_id: &str,
    content: &str,
    mentions: &[String],
) {
    let msg_id = Uuid::new_v4().to_string();
    // Mentions are stored as a JSON array string; fall back to "[]" on
    // serialization failure so the row is still well-formed.
    let mentions_json = serde_json::to_string(mentions).unwrap_or_else(|_| "[]".to_string());
    let now = chrono::Utc::now().to_rfc3339();
    // Store in database
    let _ = sqlx::query(
        "INSERT INTO messages (id, room_id, sender_id, sender_name, content, mentions, is_ai, created_at) VALUES (?, ?, ?, ?, ?, ?, 0, ?)",
    )
    .bind(&msg_id)
    .bind(room_id)
    .bind(user_id)
    .bind(display_name)
    .bind(content)
    .bind(&mentions_json)
    .bind(&now)
    .execute(&state.db)
    .await;
    // Broadcast human message
    let payload = MessagePayload {
        id: msg_id,
        room_id: room_id.to_string(),
        sender_id: user_id.to_string(),
        sender_name: display_name.to_string(),
        content: content.to_string(),
        mentions: mentions.to_vec(),
        is_ai: false,
        created_at: now,
        ai_meta: None,
    };
    let _ = state.tx.send(BroadcastEvent {
        room_id: room_id.to_string(),
        message: WsServerMessage::NewMessage {
            message: payload,
        },
    });
    // Check if AI should respond
    let ai_user_id = "ai-assistant";
    // Responds when explicitly @-mentioned...
    let should_respond = mentions.contains(&ai_user_id.to_string());
    // Also check room settings for ai_always_respond
    let room = sqlx::query_as::<_, (String, bool, String)>(
        "SELECT model_id, ai_always_respond, system_prompt FROM rooms WHERE id = ? AND deleted_at IS NULL",
    )
    .bind(room_id)
    .fetch_optional(&state.db)
    .await;
    // Missing/deleted room or query error: silently skip the AI path.
    let (model_id, always_respond, system_prompt) = match room {
        Ok(Some(r)) => r,
        _ => return,
    };
    if !should_respond && !always_respond {
        return;
    }
    // Signal AI is typing
    let _ = state.tx.send(BroadcastEvent {
        room_id: room_id.to_string(),
        message: WsServerMessage::AiTyping {
            room_id: room_id.to_string(),
        },
    });
    // Fetch recent history
    let recent_messages = sqlx::query_as::<_, (String, String, bool)>(
        "SELECT sender_name, content, is_ai FROM messages WHERE room_id = ? ORDER BY created_at DESC LIMIT 50",
    )
    .bind(room_id)
    .fetch_all(&state.db)
    .await
    .unwrap_or_default();
    // Query returned newest-first; reverse into chronological order for
    // the chat transcript.
    let history: Vec<(String, String, bool)> = recent_messages.into_iter().rev().collect();
    let mut chat_history = openrouter::build_chat_history(&system_prompt, &history);
    // Build tools for AI
    let tools = openrouter::build_tools();
    // Pre-generate AI message ID so we can reference it in stream chunks
    let ai_msg_id = Uuid::new_v4().to_string();
    // Call OpenRouter with tool loop — uses streaming for all rounds
    // Accumulators aggregated across all rounds (each round is one API call).
    let mut total_prompt_tokens: u32 = 0;
    let mut total_completion_tokens: u32 = 0;
    let mut total_response_ms: u64 = 0;
    let mut final_model = model_id.clone();
    let mut ai_response = String::new();
    let mut had_error = false;
    let mut collected_tool_results: Vec<crate::models::ToolResult> = vec![];
    'tool_loop: for round in 0..MAX_TOOL_ROUNDS {
        // Each round re-sends the (growing) chat history; a new stream is
        // opened per round.
        let mut stream_rx = openrouter::chat_completion_stream(
            chat_history.clone(),
            &model_id,
            &state.openrouter_key,
            Some(tools.clone()),
        )
        .await;
        while let Some(event) = stream_rx.recv().await {
            match event {
                openrouter::StreamEvent::Delta(text) => {
                    // Broadcast each content chunk to clients
                    let _ = state.tx.send(BroadcastEvent {
                        room_id: room_id.to_string(),
                        message: WsServerMessage::AiStreamChunk {
                            room_id: room_id.to_string(),
                            message_id: ai_msg_id.clone(),
                            delta: text.clone(),
                        },
                    });
                    // Accumulate the full text server-side for persistence.
                    // Note: deltas accumulate across rounds, so any text
                    // emitted before a tool call is kept.
                    ai_response.push_str(&text);
                }
                openrouter::StreamEvent::ToolCalls(assistant_msg, stats) => {
                    total_prompt_tokens += stats.prompt_tokens;
                    total_completion_tokens += stats.completion_tokens;
                    total_response_ms += stats.response_ms;
                    final_model = stats.model.clone();
                    tracing::info!(
                        "AI requesting tool calls (round {}): {:?}",
                        round + 1,
                        assistant_msg.tool_calls.as_ref().map(|tc| tc.iter().map(|t| &t.function.name).collect::<Vec<_>>())
                    );
                    // Add the assistant's tool-call message to history
                    let tool_calls = assistant_msg.tool_calls.clone().unwrap_or_default();
                    chat_history.push(assistant_msg);
                    // Execute each tool call and add results
                    for tool_call in &tool_calls {
                        let tool_input = extract_tool_input(&tool_call.function.name, &tool_call.function.arguments);
                        // Broadcast real-time tool usage event
                        let _ = state.tx.send(BroadcastEvent {
                            room_id: room_id.to_string(),
                            message: WsServerMessage::AiToolUsage {
                                room_id: room_id.to_string(),
                                tool_name: tool_call.function.name.clone(),
                                status: "calling".to_string(),
                            },
                        });
                        let tool_result = execute_tool(
                            &tool_call.function.name,
                            &tool_call.function.arguments,
                            &state.brave_api_key,
                        )
                        .await;
                        tracing::info!(
                            "Tool {} result: {} chars",
                            tool_call.function.name,
                            tool_result.len()
                        );
                        // Keep inputs/results for the final message's ai_meta.
                        collected_tool_results.push(crate::models::ToolResult {
                            tool: tool_call.function.name.clone(),
                            input: tool_input,
                            result: tool_result.clone(),
                        });
                        // Feed the tool result back to the model, keyed by
                        // the originating tool_call id.
                        chat_history.push(openrouter::ChatMessage {
                            role: "tool".into(),
                            content: Some(tool_result),
                            tool_calls: None,
                            tool_call_id: Some(tool_call.id.clone()),
                        });
                    }
                    // Continue to next round (tool loop)
                    // (drops the rest of this round's stream)
                    continue 'tool_loop;
                }
                openrouter::StreamEvent::Done(stats) => {
                    total_prompt_tokens += stats.prompt_tokens;
                    total_completion_tokens += stats.completion_tokens;
                    total_response_ms += stats.response_ms;
                    final_model = stats.model;
                    break 'tool_loop;
                }
                openrouter::StreamEvent::Error(e) => {
                    tracing::error!("OpenRouter stream error (round {}): {}", round + 1, e);
                    // Replace any partial content with a user-visible error.
                    ai_response = format!("*Sorry, I encountered an error: {}*", e);
                    had_error = true;
                    break 'tool_loop;
                }
            }
        }
    }
    // If we exhausted all rounds without a text response, note it
    if ai_response.is_empty() && !had_error {
        ai_response = "*I used several tools but couldn't formulate a final response. Please try again.*".to_string();
    }
    // Signal stream end so client can finalize rendering
    let _ = state.tx.send(BroadcastEvent {
        room_id: room_id.to_string(),
        message: WsServerMessage::AiStreamEnd {
            room_id: room_id.to_string(),
            message_id: ai_msg_id.clone(),
        },
    });
    // Usage metadata accompanies the final message; omitted entirely on error.
    let ai_meta = if !had_error {
        Some(crate::models::AiMeta {
            model: final_model,
            prompt_tokens: total_prompt_tokens,
            completion_tokens: total_completion_tokens,
            total_tokens: total_prompt_tokens + total_completion_tokens,
            response_ms: total_response_ms,
            tool_results: if collected_tool_results.is_empty() {
                None
            } else {
                Some(collected_tool_results)
            },
        })
    } else {
        None
    };
    // Store AI response
    let ai_now = chrono::Utc::now().to_rfc3339();
    // Serialize ai_meta for database storage
    let ai_meta_json = ai_meta.as_ref().and_then(|m| serde_json::to_string(m).ok());
    let _ = sqlx::query(
        "INSERT INTO messages (id, room_id, sender_id, sender_name, content, mentions, is_ai, created_at, ai_meta) VALUES (?, ?, ?, ?, ?, '[]', 1, ?, ?)",
    )
    .bind(&ai_msg_id)
    .bind(room_id)
    .bind(ai_user_id)
    .bind("AI Assistant")
    .bind(&ai_response)
    .bind(&ai_now)
    .bind(&ai_meta_json)
    .execute(&state.db)
    .await;
    // Broadcast final AI message (includes full content + ai_meta)
    let ai_payload = MessagePayload {
        id: ai_msg_id,
        room_id: room_id.to_string(),
        sender_id: ai_user_id.to_string(),
        sender_name: "AI Assistant".to_string(),
        content: ai_response,
        mentions: vec![],
        is_ai: true,
        created_at: ai_now,
        ai_meta,
    };
    let _ = state.tx.send(BroadcastEvent {
        room_id: room_id.to_string(),
        message: WsServerMessage::NewMessage {
            message: ai_payload,
        },
    });
}
/// Extract a human-readable input string from tool arguments (for UI display).
///
/// Known tools map to a single display field; unknown tools fall back to the
/// raw argument blob.
fn extract_tool_input(tool_name: &str, arguments: &str) -> String {
    // Malformed JSON degrades to Null; field lookups on Null yield "".
    let parsed: serde_json::Value = serde_json::from_str(arguments).unwrap_or_default();
    let display_field = match tool_name {
        "brave_search" => Some("query"),
        "web_fetch" => Some("url"),
        _ => None,
    };
    match display_field {
        Some(key) => parsed[key].as_str().unwrap_or("").to_string(),
        None => arguments.to_string(),
    }
}
/// Execute a tool call by name, returning the result as a string.
///
/// Tool failures are reported as formatted error strings rather than
/// propagated, so the model always receives a textual tool result.
async fn execute_tool(name: &str, arguments: &str, brave_api_key: &str) -> String {
    // Arguments arrive as a JSON blob; malformed JSON degrades to Null and
    // trips the required-field guards below.
    let args: serde_json::Value = serde_json::from_str(arguments).unwrap_or_default();
    match name {
        "brave_search" => {
            let query = args["query"].as_str().unwrap_or("");
            if query.is_empty() {
                return "Error: search query is required".into();
            }
            let count = args["count"].as_u64().unwrap_or(5) as u8;
            match brave::search(query, brave_api_key, count).await {
                Ok(results) => brave::format_results(&results),
                Err(e) => format!("Search error: {}", e),
            }
        }
        "web_fetch" => {
            let url = args["url"].as_str().unwrap_or("");
            if url.is_empty() {
                return "Error: URL is required".into();
            }
            match fetch::fetch_url(url, 8000).await {
                Ok(result) => fetch::format_result(&result),
                Err(e) => format!("Fetch error: {}", e),
            }
        }
        _ => format!("Unknown tool: {}", name),
    }
}