Replaces batch AI responses with real-time SSE streaming from OpenRouter. Tokens are buffered client-side and drained via requestAnimationFrame for a smooth typing effect instead of choppy chunk dumps. Backend: - Rewrite openrouter service for SSE streaming with incremental tool call accumulation - Add AiStreamChunk/AiStreamEnd WebSocket event types - Stream content deltas to clients during all tool call rounds - Increase broadcast channel capacity (256 -> 4096) and handle Lagged errors gracefully Frontend: - Add StreamBuffer utility with adaptive rAF-based character draining - Show streaming message-bubble with blinking cursor during generation - Clean up buffer on room switch and final message replacement Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
400 lines
13 KiB
Rust
use futures::StreamExt;
|
|
use serde::{Deserialize, Serialize};
|
|
|
|
const OPENROUTER_API_URL: &str = "https://openrouter.ai/api/v1/chat/completions";
|
|
|
|
// ── Request types ──
|
|
|
|
#[derive(Debug, Serialize)]
|
|
struct ChatRequest {
|
|
model: String,
|
|
messages: Vec<ChatMessage>,
|
|
max_tokens: Option<u32>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
tools: Option<Vec<Tool>>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
stream: Option<bool>,
|
|
}
|
|
|
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
|
pub struct ChatMessage {
|
|
pub role: String,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub content: Option<String>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub tool_calls: Option<Vec<ToolCall>>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub tool_call_id: Option<String>,
|
|
}
|
|
|
|
// ── Tool definition types ──
|
|
|
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
|
pub struct Tool {
|
|
pub r#type: String,
|
|
pub function: ToolFunction,
|
|
}
|
|
|
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
|
pub struct ToolFunction {
|
|
pub name: String,
|
|
pub description: String,
|
|
pub parameters: serde_json::Value,
|
|
}
|
|
|
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
|
pub struct ToolCall {
|
|
pub id: String,
|
|
pub r#type: String,
|
|
pub function: ToolCallFunction,
|
|
}
|
|
|
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
|
pub struct ToolCallFunction {
|
|
pub name: String,
|
|
pub arguments: String,
|
|
}
|
|
|
|
// ── Response types ──
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct Usage {
|
|
prompt_tokens: Option<u32>,
|
|
completion_tokens: Option<u32>,
|
|
total_tokens: Option<u32>,
|
|
}
|
|
|
|
// ── Streaming response types ──
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct StreamChunk {
|
|
choices: Option<Vec<StreamChoice>>,
|
|
model: Option<String>,
|
|
usage: Option<Usage>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
#[allow(dead_code)]
|
|
struct StreamChoice {
|
|
delta: Option<StreamDelta>,
|
|
finish_reason: Option<String>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct StreamDelta {
|
|
content: Option<String>,
|
|
#[serde(default)]
|
|
tool_calls: Option<Vec<StreamToolCall>>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
#[allow(dead_code)]
|
|
struct StreamToolCall {
|
|
index: Option<usize>,
|
|
id: Option<String>,
|
|
#[serde(default)]
|
|
r#type: Option<String>,
|
|
function: Option<StreamToolCallFunction>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct StreamToolCallFunction {
|
|
name: Option<String>,
|
|
arguments: Option<String>,
|
|
}
|
|
|
|
/// Events emitted during a streaming completion.
|
|
#[derive(Debug)]
|
|
pub enum StreamEvent {
|
|
/// A chunk of content text.
|
|
Delta(String),
|
|
/// AI wants to call tools (streaming accumulated the full tool_calls).
|
|
ToolCalls(ChatMessage, CompletionStats),
|
|
/// Stream finished with final stats.
|
|
Done(CompletionStats),
|
|
/// An error occurred.
|
|
Error(String),
|
|
}
|
|
|
|
/// Stats returned alongside an AI completion.
|
|
#[derive(Debug, Clone, Serialize)]
|
|
pub struct CompletionStats {
|
|
pub model: String,
|
|
pub prompt_tokens: u32,
|
|
pub completion_tokens: u32,
|
|
pub total_tokens: u32,
|
|
pub response_ms: u64,
|
|
}
|
|
|
|
/// Build the tool definitions for brave_search and web_fetch.
|
|
pub fn build_tools() -> Vec<Tool> {
|
|
vec![
|
|
Tool {
|
|
r#type: "function".into(),
|
|
function: ToolFunction {
|
|
name: "brave_search".into(),
|
|
description: "Search the web for current information. Use this when users ask about recent events, need factual data you're unsure about, or want up-to-date information.".into(),
|
|
parameters: serde_json::json!({
|
|
"type": "object",
|
|
"properties": {
|
|
"query": {
|
|
"type": "string",
|
|
"description": "The search query"
|
|
},
|
|
"count": {
|
|
"type": "integer",
|
|
"description": "Number of results (1-10, default 5)"
|
|
}
|
|
},
|
|
"required": ["query"]
|
|
}),
|
|
},
|
|
},
|
|
Tool {
|
|
r#type: "function".into(),
|
|
function: ToolFunction {
|
|
name: "web_fetch".into(),
|
|
description: "Fetch and read the content of a web page. Use this to read articles, documentation, or any URL shared by users or found in search results.".into(),
|
|
parameters: serde_json::json!({
|
|
"type": "object",
|
|
"properties": {
|
|
"url": {
|
|
"type": "string",
|
|
"description": "The URL to fetch"
|
|
}
|
|
},
|
|
"required": ["url"]
|
|
}),
|
|
},
|
|
},
|
|
]
|
|
}
|
|
|
|
/// Send a streaming chat completion request to OpenRouter.
|
|
/// Sends events via the returned mpsc receiver:
|
|
/// - `StreamEvent::Delta(text)` for each content chunk
|
|
/// - `StreamEvent::ToolCalls(msg, stats)` if the AI wants to call tools
|
|
/// - `StreamEvent::Done(stats)` when the stream finishes with a text response
|
|
/// - `StreamEvent::Error(msg)` on error
|
|
pub async fn chat_completion_stream(
|
|
history: Vec<ChatMessage>,
|
|
model_id: &str,
|
|
api_key: &str,
|
|
tools: Option<Vec<Tool>>,
|
|
) -> tokio::sync::mpsc::Receiver<StreamEvent> {
|
|
let (tx, rx) = tokio::sync::mpsc::channel::<StreamEvent>(256);
|
|
|
|
let model_id = model_id.to_string();
|
|
let api_key = api_key.to_string();
|
|
|
|
tokio::spawn(async move {
|
|
let client = reqwest::Client::new();
|
|
|
|
let request_body = ChatRequest {
|
|
model: model_id.clone(),
|
|
messages: history,
|
|
max_tokens: Some(2048),
|
|
tools,
|
|
stream: Some(true),
|
|
};
|
|
|
|
let start = std::time::Instant::now();
|
|
|
|
let response = match client
|
|
.post(OPENROUTER_API_URL)
|
|
.header("Authorization", format!("Bearer {}", api_key))
|
|
.header("Content-Type", "application/json")
|
|
.header("HTTP-Referer", "http://localhost:3001")
|
|
.header("X-Title", "GroupChat")
|
|
.json(&request_body)
|
|
.send()
|
|
.await
|
|
{
|
|
Ok(r) => r,
|
|
Err(e) => {
|
|
let _ = tx.send(StreamEvent::Error(format!("Request failed: {}", e))).await;
|
|
return;
|
|
}
|
|
};
|
|
|
|
if !response.status().is_success() {
|
|
let status = response.status();
|
|
let body = response.text().await.unwrap_or_default();
|
|
let _ = tx.send(StreamEvent::Error(format!("OpenRouter error {}: {}", status, body))).await;
|
|
return;
|
|
}
|
|
|
|
// Read SSE stream
|
|
let mut byte_stream = response.bytes_stream();
|
|
let mut buffer = String::new();
|
|
let mut full_content = String::new();
|
|
let mut model_name = model_id.clone();
|
|
|
|
// Accumulators for streamed tool calls
|
|
let mut tool_call_accum: Vec<ToolCall> = Vec::new();
|
|
let mut has_tool_calls = false;
|
|
|
|
// Usage from the final chunk (some providers include it)
|
|
let mut final_usage: Option<Usage> = None;
|
|
|
|
while let Some(chunk_result) = byte_stream.next().await {
|
|
let bytes = match chunk_result {
|
|
Ok(b) => b,
|
|
Err(e) => {
|
|
let _ = tx.send(StreamEvent::Error(format!("Stream error: {}", e))).await;
|
|
return;
|
|
}
|
|
};
|
|
|
|
buffer.push_str(&String::from_utf8_lossy(&bytes));
|
|
|
|
// Process complete SSE lines
|
|
while let Some(line_end) = buffer.find('\n') {
|
|
let line = buffer[..line_end].trim().to_string();
|
|
buffer = buffer[line_end + 1..].to_string();
|
|
|
|
if line.is_empty() || line.starts_with(':') {
|
|
continue;
|
|
}
|
|
|
|
if let Some(data) = line.strip_prefix("data: ") {
|
|
let data = data.trim();
|
|
|
|
if data == "[DONE]" {
|
|
// Stream finished
|
|
continue;
|
|
}
|
|
|
|
let chunk: StreamChunk = match serde_json::from_str(data) {
|
|
Ok(c) => c,
|
|
Err(_) => continue,
|
|
};
|
|
|
|
if let Some(m) = &chunk.model {
|
|
model_name = m.clone();
|
|
}
|
|
|
|
if let Some(u) = chunk.usage {
|
|
final_usage = Some(u);
|
|
}
|
|
|
|
if let Some(choices) = &chunk.choices {
|
|
if let Some(choice) = choices.first() {
|
|
if let Some(delta) = &choice.delta {
|
|
// Handle content delta
|
|
if let Some(content) = &delta.content {
|
|
if !content.is_empty() {
|
|
full_content.push_str(content);
|
|
let _ = tx.send(StreamEvent::Delta(content.clone())).await;
|
|
}
|
|
}
|
|
|
|
// Handle tool call deltas (accumulate incrementally)
|
|
if let Some(tcs) = &delta.tool_calls {
|
|
has_tool_calls = true;
|
|
for tc in tcs {
|
|
let idx = tc.index.unwrap_or(0);
|
|
|
|
// Ensure we have enough slots
|
|
while tool_call_accum.len() <= idx {
|
|
tool_call_accum.push(ToolCall {
|
|
id: String::new(),
|
|
r#type: "function".to_string(),
|
|
function: ToolCallFunction {
|
|
name: String::new(),
|
|
arguments: String::new(),
|
|
},
|
|
});
|
|
}
|
|
|
|
if let Some(id) = &tc.id {
|
|
tool_call_accum[idx].id = id.clone();
|
|
}
|
|
if let Some(func) = &tc.function {
|
|
if let Some(name) = &func.name {
|
|
tool_call_accum[idx].function.name.push_str(name);
|
|
}
|
|
if let Some(args) = &func.arguments {
|
|
tool_call_accum[idx].function.arguments.push_str(args);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
let elapsed_ms = start.elapsed().as_millis() as u64;
|
|
|
|
let (prompt_tokens, completion_tokens, total_tokens) = match &final_usage {
|
|
Some(u) => (
|
|
u.prompt_tokens.unwrap_or(0),
|
|
u.completion_tokens.unwrap_or(0),
|
|
u.total_tokens.unwrap_or(0),
|
|
),
|
|
None => (0, 0, 0),
|
|
};
|
|
|
|
let stats = CompletionStats {
|
|
model: model_name,
|
|
prompt_tokens,
|
|
completion_tokens,
|
|
total_tokens,
|
|
response_ms: elapsed_ms,
|
|
};
|
|
|
|
if has_tool_calls && !tool_call_accum.is_empty() {
|
|
// AI requested tool calls
|
|
let assistant_msg = ChatMessage {
|
|
role: "assistant".into(),
|
|
content: if full_content.is_empty() { None } else { Some(full_content) },
|
|
tool_calls: Some(tool_call_accum),
|
|
tool_call_id: None,
|
|
};
|
|
let _ = tx.send(StreamEvent::ToolCalls(assistant_msg, stats)).await;
|
|
} else {
|
|
// Normal text response completed
|
|
let _ = tx.send(StreamEvent::Done(stats)).await;
|
|
}
|
|
});
|
|
|
|
rx
|
|
}
|
|
|
|
/// Build the message history for OpenRouter from stored messages.
|
|
/// Includes the system prompt as the first message.
|
|
pub fn build_chat_history(
|
|
system_prompt: &str,
|
|
messages: &[(String, String, bool)], // (sender_name, content, is_ai)
|
|
) -> Vec<ChatMessage> {
|
|
let mut history = vec![ChatMessage {
|
|
role: "system".to_string(),
|
|
content: Some(system_prompt.to_string()),
|
|
tool_calls: None,
|
|
tool_call_id: None,
|
|
}];
|
|
|
|
for (sender_name, content, is_ai) in messages {
|
|
if *is_ai {
|
|
history.push(ChatMessage {
|
|
role: "assistant".to_string(),
|
|
content: Some(content.clone()),
|
|
tool_calls: None,
|
|
tool_call_id: None,
|
|
});
|
|
} else {
|
|
history.push(ChatMessage {
|
|
role: "user".to_string(),
|
|
content: Some(format!("[{}]: {}", sender_name, content)),
|
|
tool_calls: None,
|
|
tool_call_id: None,
|
|
});
|
|
}
|
|
}
|
|
|
|
history
|
|
}
|