groupchat/server/src/services/openrouter.rs
Jason Tudisco 4a002c85d4 feat: add streaming AI responses with smooth token-by-token rendering
Replaces batch AI responses with real-time SSE streaming from OpenRouter.
Tokens are buffered client-side and drained via requestAnimationFrame for
a smooth typing effect instead of choppy chunk dumps.

Backend:
- Rewrite openrouter service for SSE streaming with incremental tool call accumulation
- Add AiStreamChunk/AiStreamEnd WebSocket event types
- Stream content deltas to clients during all tool call rounds
- Increase broadcast channel capacity (256 -> 4096) and handle Lagged errors gracefully

Frontend:
- Add StreamBuffer utility with adaptive rAF-based character draining
- Show streaming message-bubble with blinking cursor during generation
- Clean up buffer on room switch and final message replacement

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-06 20:23:49 -06:00

400 lines
13 KiB
Rust

use futures::StreamExt;
use serde::{Deserialize, Serialize};
const OPENROUTER_API_URL: &str = "https://openrouter.ai/api/v1/chat/completions";
// ── Request types ──
/// JSON request body for the OpenRouter chat-completions endpoint.
#[derive(Debug, Serialize)]
struct ChatRequest {
    /// Model identifier, passed through from the caller.
    model: String,
    /// Conversation history (system prompt first, then user/assistant turns).
    messages: Vec<ChatMessage>,
    /// Completion length cap. Skipped when `None` so the provider default
    /// applies — previously this serialized as `"max_tokens": null`, which
    /// is inconsistent with the sibling optional fields and can be rejected
    /// by strict providers.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    /// Tool definitions offered to the model, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<Tool>>,
    /// `Some(true)` requests an SSE token stream instead of a single body.
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
}
/// One message in an OpenAI-style chat transcript.
///
/// Used both for requests sent to OpenRouter and for assistant messages
/// appended back into the history during tool-call rounds.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ChatMessage {
/// Chat role, e.g. "system", "user", or "assistant" (see `build_chat_history`).
pub role: String,
/// Message text; `None` for assistant messages that carry only tool calls.
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
/// Tool invocations requested by the assistant, if any.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
/// Presumably the id of the tool call a tool-result message answers
/// (OpenAI convention) — no producer visible in this file; confirm at callers.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
}
// ── Tool definition types ──
/// A tool the model may call, in OpenAI function-calling format.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Tool {
/// Always "function" for the tools built in this file.
pub r#type: String,
pub function: ToolFunction,
}
/// The callable function a `Tool` exposes to the model.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ToolFunction {
pub name: String,
pub description: String,
/// JSON Schema describing the function's arguments.
pub parameters: serde_json::Value,
}
/// A fully-assembled tool invocation requested by the assistant.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ToolCall {
/// Provider-assigned id; echoed back in `ChatMessage::tool_call_id`.
pub id: String,
/// Always "function" for the calls accumulated in this file.
pub r#type: String,
pub function: ToolCallFunction,
}
/// The function name and raw argument payload of a `ToolCall`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ToolCallFunction {
pub name: String,
/// JSON-encoded arguments as a string (accumulated verbatim from deltas).
pub arguments: String,
}
// ── Response types ──
/// Token-usage block from the API. All fields optional because not every
/// provider reports usage (see the `None => (0, 0, 0)` fallback below).
#[derive(Debug, Deserialize)]
struct Usage {
prompt_tokens: Option<u32>,
completion_tokens: Option<u32>,
total_tokens: Option<u32>,
}
// ── Streaming response types ──
/// One parsed SSE `data:` payload from the streaming endpoint.
#[derive(Debug, Deserialize)]
struct StreamChunk {
choices: Option<Vec<StreamChoice>>,
model: Option<String>,
/// Some providers attach usage stats to the final chunk.
usage: Option<Usage>,
}
/// A single choice within a `StreamChunk`; only the first is consumed.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct StreamChoice {
delta: Option<StreamDelta>,
/// Deserialized but currently unused (hence the `allow(dead_code)`).
finish_reason: Option<String>,
}
/// Incremental payload of a streamed choice: new content text and/or
/// partial tool-call fragments.
#[derive(Debug, Deserialize)]
struct StreamDelta {
content: Option<String>,
#[serde(default)]
tool_calls: Option<Vec<StreamToolCall>>,
}
/// A partial tool call as streamed by the API; fields arrive piecemeal
/// across chunks and are stitched together by index in the accumulator.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct StreamToolCall {
/// Slot this fragment belongs to; defaults to 0 when absent.
index: Option<usize>,
/// Present only on the first fragment of a call.
id: Option<String>,
#[serde(default)]
r#type: Option<String>,
function: Option<StreamToolCallFunction>,
}
/// Partial function name/arguments within a `StreamToolCall` fragment;
/// both are appended onto the accumulated call.
#[derive(Debug, Deserialize)]
struct StreamToolCallFunction {
name: Option<String>,
arguments: Option<String>,
}
/// Events emitted during a streaming completion.
///
/// Exactly one of `ToolCalls`, `Done`, or `Error` terminates a stream;
/// zero or more `Delta`s precede it.
#[derive(Debug)]
pub enum StreamEvent {
/// A chunk of content text.
Delta(String),
/// AI wants to call tools (streaming accumulated the full tool_calls).
ToolCalls(ChatMessage, CompletionStats),
/// Stream finished with final stats.
Done(CompletionStats),
/// An error occurred.
Error(String),
}
/// Stats returned alongside an AI completion.
#[derive(Debug, Clone, Serialize)]
pub struct CompletionStats {
/// Model name as reported by the stream (falls back to the requested id).
pub model: String,
/// Token counts; 0 when the provider did not report usage.
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
/// Wall-clock time from request send to end of stream.
pub response_ms: u64,
}
/// Build the tool definitions for brave_search and web_fetch.
pub fn build_tools() -> Vec<Tool> {
vec![
Tool {
r#type: "function".into(),
function: ToolFunction {
name: "brave_search".into(),
description: "Search the web for current information. Use this when users ask about recent events, need factual data you're unsure about, or want up-to-date information.".into(),
parameters: serde_json::json!({
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The search query"
},
"count": {
"type": "integer",
"description": "Number of results (1-10, default 5)"
}
},
"required": ["query"]
}),
},
},
Tool {
r#type: "function".into(),
function: ToolFunction {
name: "web_fetch".into(),
description: "Fetch and read the content of a web page. Use this to read articles, documentation, or any URL shared by users or found in search results.".into(),
parameters: serde_json::json!({
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "The URL to fetch"
}
},
"required": ["url"]
}),
},
},
]
}
/// Send a streaming chat completion request to OpenRouter.
/// Sends events via the returned mpsc receiver:
/// - `StreamEvent::Delta(text)` for each content chunk
/// - `StreamEvent::ToolCalls(msg, stats)` if the AI wants to call tools
/// - `StreamEvent::Done(stats)` when the stream finishes with a text response
/// - `StreamEvent::Error(msg)` on error
pub async fn chat_completion_stream(
history: Vec<ChatMessage>,
model_id: &str,
api_key: &str,
tools: Option<Vec<Tool>>,
) -> tokio::sync::mpsc::Receiver<StreamEvent> {
let (tx, rx) = tokio::sync::mpsc::channel::<StreamEvent>(256);
let model_id = model_id.to_string();
let api_key = api_key.to_string();
tokio::spawn(async move {
let client = reqwest::Client::new();
let request_body = ChatRequest {
model: model_id.clone(),
messages: history,
max_tokens: Some(2048),
tools,
stream: Some(true),
};
let start = std::time::Instant::now();
let response = match client
.post(OPENROUTER_API_URL)
.header("Authorization", format!("Bearer {}", api_key))
.header("Content-Type", "application/json")
.header("HTTP-Referer", "http://localhost:3001")
.header("X-Title", "GroupChat")
.json(&request_body)
.send()
.await
{
Ok(r) => r,
Err(e) => {
let _ = tx.send(StreamEvent::Error(format!("Request failed: {}", e))).await;
return;
}
};
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
let _ = tx.send(StreamEvent::Error(format!("OpenRouter error {}: {}", status, body))).await;
return;
}
// Read SSE stream
let mut byte_stream = response.bytes_stream();
let mut buffer = String::new();
let mut full_content = String::new();
let mut model_name = model_id.clone();
// Accumulators for streamed tool calls
let mut tool_call_accum: Vec<ToolCall> = Vec::new();
let mut has_tool_calls = false;
// Usage from the final chunk (some providers include it)
let mut final_usage: Option<Usage> = None;
while let Some(chunk_result) = byte_stream.next().await {
let bytes = match chunk_result {
Ok(b) => b,
Err(e) => {
let _ = tx.send(StreamEvent::Error(format!("Stream error: {}", e))).await;
return;
}
};
buffer.push_str(&String::from_utf8_lossy(&bytes));
// Process complete SSE lines
while let Some(line_end) = buffer.find('\n') {
let line = buffer[..line_end].trim().to_string();
buffer = buffer[line_end + 1..].to_string();
if line.is_empty() || line.starts_with(':') {
continue;
}
if let Some(data) = line.strip_prefix("data: ") {
let data = data.trim();
if data == "[DONE]" {
// Stream finished
continue;
}
let chunk: StreamChunk = match serde_json::from_str(data) {
Ok(c) => c,
Err(_) => continue,
};
if let Some(m) = &chunk.model {
model_name = m.clone();
}
if let Some(u) = chunk.usage {
final_usage = Some(u);
}
if let Some(choices) = &chunk.choices {
if let Some(choice) = choices.first() {
if let Some(delta) = &choice.delta {
// Handle content delta
if let Some(content) = &delta.content {
if !content.is_empty() {
full_content.push_str(content);
let _ = tx.send(StreamEvent::Delta(content.clone())).await;
}
}
// Handle tool call deltas (accumulate incrementally)
if let Some(tcs) = &delta.tool_calls {
has_tool_calls = true;
for tc in tcs {
let idx = tc.index.unwrap_or(0);
// Ensure we have enough slots
while tool_call_accum.len() <= idx {
tool_call_accum.push(ToolCall {
id: String::new(),
r#type: "function".to_string(),
function: ToolCallFunction {
name: String::new(),
arguments: String::new(),
},
});
}
if let Some(id) = &tc.id {
tool_call_accum[idx].id = id.clone();
}
if let Some(func) = &tc.function {
if let Some(name) = &func.name {
tool_call_accum[idx].function.name.push_str(name);
}
if let Some(args) = &func.arguments {
tool_call_accum[idx].function.arguments.push_str(args);
}
}
}
}
}
}
}
}
}
}
let elapsed_ms = start.elapsed().as_millis() as u64;
let (prompt_tokens, completion_tokens, total_tokens) = match &final_usage {
Some(u) => (
u.prompt_tokens.unwrap_or(0),
u.completion_tokens.unwrap_or(0),
u.total_tokens.unwrap_or(0),
),
None => (0, 0, 0),
};
let stats = CompletionStats {
model: model_name,
prompt_tokens,
completion_tokens,
total_tokens,
response_ms: elapsed_ms,
};
if has_tool_calls && !tool_call_accum.is_empty() {
// AI requested tool calls
let assistant_msg = ChatMessage {
role: "assistant".into(),
content: if full_content.is_empty() { None } else { Some(full_content) },
tool_calls: Some(tool_call_accum),
tool_call_id: None,
};
let _ = tx.send(StreamEvent::ToolCalls(assistant_msg, stats)).await;
} else {
// Normal text response completed
let _ = tx.send(StreamEvent::Done(stats)).await;
}
});
rx
}
/// Build the message history for OpenRouter from stored messages.
/// Includes the system prompt as the first message.
pub fn build_chat_history(
    system_prompt: &str,
    messages: &[(String, String, bool)], // (sender_name, content, is_ai)
) -> Vec<ChatMessage> {
    // Leading system message carrying the prompt verbatim.
    let system = ChatMessage {
        role: "system".to_string(),
        content: Some(system_prompt.to_string()),
        tool_calls: None,
        tool_call_id: None,
    };
    // AI turns map to assistant messages as-is; human turns become user
    // messages with the sender name prefixed so the model can tell
    // participants apart.
    let turns = messages.iter().map(|(sender_name, content, is_ai)| {
        let (role, text) = if *is_ai {
            ("assistant".to_string(), content.clone())
        } else {
            ("user".to_string(), format!("[{}]: {}", sender_name, content))
        };
        ChatMessage {
            role,
            content: Some(text),
            tool_calls: None,
            tool_call_id: None,
        }
    });
    std::iter::once(system).chain(turns).collect()
}