feat: add streaming AI responses with smooth token-by-token rendering
Replaces batch AI responses with real-time SSE streaming from OpenRouter.
Tokens are buffered client-side and drained via requestAnimationFrame for a
smooth typing effect instead of choppy chunk dumps.

Backend:
- Rewrite openrouter service for SSE streaming with incremental tool call accumulation
- Add AiStreamChunk/AiStreamEnd WebSocket event types
- Stream content deltas to clients during all tool call rounds
- Increase broadcast channel capacity (256 -> 4096) and handle Lagged errors gracefully

Frontend:
- Add StreamBuffer utility with adaptive rAF-based character draining
- Show streaming message-bubble with blinking cursor during generation
- Clean up buffer on room switch and final message replacement

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
parent 01258fa958
commit 4a002c85d4
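Before the file-by-file diff, a minimal sketch of the client-side flow this commit sets up, assuming only the StreamBuffer API and WebSocket event names introduced below; appendToBubble and hideCursor are hypothetical stand-ins for the riot component updates in the actual diff:

    import { ws } from '../services/websocket.js'
    import { StreamBuffer } from '../services/stream-buffer.js'

    // Drain side: StreamBuffer calls onText with a few characters per animation frame.
    const buffer = new StreamBuffer(
      (text) => appendToBubble(text), // onText: append drained chars to the visible bubble
      () => hideCursor(),             // onDone: queue empty after the stream has ended
    )

    // Fill side: raw SSE deltas arrive in bursts; the buffer smooths them out.
    ws.on('ai_stream_chunk', (msg) => buffer.push(msg.delta))
    ws.on('ai_stream_end', () => buffer.finish())
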
@@ -26,6 +26,7 @@
       ai-typing={state.aiTyping}
       ai-tool-status={state.aiToolStatus}
       typing-users={state.typingUsers}
+      streaming-message={state.streamingMessage}
       cb-send={sendMessage}
       cb-invite={() => update({ showInviteModal: true })}
       cb-delete-room={() => update({ showDeleteModal: true })}
@@ -142,6 +143,7 @@
 <script>
 import { api, saveAuth, getUser, clearAuth, isAuthenticated } from '../services/api.js'
 import { ws } from '../services/websocket.js'
+import { StreamBuffer } from '../services/stream-buffer.js'

 export default {
   state: {
@@ -157,6 +159,7 @@
     showClearModal: false,
     aiTyping: false,
     aiToolStatus: null,
+    streamingMessage: null,
     typingUsers: [],
   },

@@ -178,15 +181,76 @@

     ws.on('new_message', (msg) => {
       if (msg.message.room_id === this.state.activeRoomId) {
+        // If we were streaming this message, cancel the buffer and remove placeholder
+        const isStreamReplacement = this.state.streamingMessage?.id === msg.message.id
+        if (isStreamReplacement && this.streamBuffer) {
+          this.streamBuffer.cancel()
+          this.streamBuffer = null
+          this._streamMsgId = null
+          this._streamContent = ''
+        }
         this.update({
           messages: [...this.state.messages, msg.message],
           aiTyping: false,
           aiToolStatus: null,
+          streamingMessage: isStreamReplacement ? null : this.state.streamingMessage,
         })
         this.scrollToBottom()
       }
     })

+    ws.on('ai_stream_chunk', (msg) => {
+      if (msg.room_id === this.state.activeRoomId) {
+        if (!this.streamBuffer || this._streamMsgId !== msg.message_id) {
+          // First chunk for a new message — create buffer + streaming message
+          if (this.streamBuffer) this.streamBuffer.cancel()
+          this._streamMsgId = msg.message_id
+          this._streamContent = ''
+
+          this.streamBuffer = new StreamBuffer(
+            // onText: drip chars into the displayed message
+            (text) => {
+              this._streamContent += text
+              this.update({
+                streamingMessage: {
+                  id: this._streamMsgId,
+                  room_id: msg.room_id,
+                  sender_id: 'ai-assistant',
+                  sender_name: 'AI Assistant',
+                  content: this._streamContent,
+                  mentions: [],
+                  is_ai: true,
+                  streaming: true,
+                },
+              })
+              this.scrollToBottom()
+            },
+            // onDone: buffer fully drained after stream ended
+            () => {
+              if (this.state.streamingMessage?.id === this._streamMsgId) {
+                this.update({
+                  streamingMessage: { ...this.state.streamingMessage, streaming: false },
+                })
+              }
+            }
+          )
+
+          this.update({ aiTyping: false })
+        }
+        // Push every chunk into the buffer (it drains smoothly via rAF)
+        this.streamBuffer.push(msg.delta)
+      }
+    })
+
+    ws.on('ai_stream_end', (msg) => {
+      if (msg.room_id === this.state.activeRoomId) {
+        // Tell buffer to flush remaining text, then signal done
+        if (this.streamBuffer && this._streamMsgId === msg.message_id) {
+          this.streamBuffer.finish()
+        }
+      }
+    })
+
     ws.on('ai_typing', (msg) => {
       if (msg.room_id === this.state.activeRoomId) {
         this.update({ aiTyping: true })
@@ -259,6 +323,13 @@

   async selectRoom(roomId) {
     try {
+      // Cancel any active stream buffer when switching rooms
+      if (this.streamBuffer) {
+        this.streamBuffer.cancel()
+        this.streamBuffer = null
+        this._streamMsgId = null
+        this._streamContent = ''
+      }
       const [room, messages] = await Promise.all([
         api.getRoom(roomId),
         api.getMessages(roomId),
@@ -269,6 +340,7 @@
         messages,
         aiTyping: false,
         aiToolStatus: null,
+        streamingMessage: null,
         typingUsers: [],
       })
       ws.joinRoom(roomId)

@@ -40,7 +40,17 @@
       />
     </div>

-    <div if={props.aiTyping} class="typing-indicator ai-typing">
+    <!-- Streaming AI message (live content) -->
+    <div if={props.streamingMessage} key="streaming">
+      <message-bubble
+        message={props.streamingMessage}
+        is-own={false}
+        is-streaming={true}
+      />
+    </div>
+
+    <!-- AI typing indicator (only when NOT streaming content) -->
+    <div if={props.aiTyping && !props.streamingMessage} class="typing-indicator ai-typing">
       <div class="typing-avatar ai-avatar">AI</div>
       <template if={props.aiToolStatus}>
         <span class="tool-status-text">

@@ -25,7 +25,8 @@
       </div>
     </div>
   </div>
-  <div class="message-content markdown-content"></div>
+  <div if={props.isStreaming} class="message-content streaming-content">{props.message?.content}<span class="streaming-cursor">▌</span></div>
+  <div if={!props.isStreaming} class="message-content markdown-content"></div>
   <div if={props.message?.is_ai && props.message?.ai_meta} class="ai-stats-bar">
     <button class="ai-stat-btn" onclick={copyFullMessage} title="Copy response">
       <svg width="13" height="13" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2">
@@ -280,6 +281,22 @@
   color: var(--accent);
   font-weight: 500;
 }
+
+.streaming-content {
+  white-space: pre-wrap;
+  word-wrap: break-word;
+}
+
+.streaming-cursor {
+  animation: cursor-blink 0.8s step-end infinite;
+  color: var(--accent);
+  font-weight: 300;
+}
+
+@keyframes cursor-blink {
+  0%, 100% { opacity: 1; }
+  50% { opacity: 0; }
+}
 </style>

 <script>
@@ -295,7 +312,8 @@
   },

   renderContent() {
-    const el = this.$('.message-content')
+    if (this.props.isStreaming) return // Don't markdown-render while streaming
+    const el = this.$('.message-content.markdown-content')
     if (el && this.props.message?.content) {
       el.innerHTML = renderMarkdown(this.props.message.content)
       // Inject copy buttons into code blocks

client/src/services/stream-buffer.js (new file, 92 lines)
@@ -0,0 +1,92 @@
+/**
+ * StreamBuffer — smooth token-by-token rendering for LLM streams.
+ *
+ * Network delivers tokens in bursts (5-10 at once, then a pause).
+ * This buffer queues incoming text and drains it at a steady rate
+ * via requestAnimationFrame, creating the smooth "typing" effect
+ * seen in ChatGPT / Claude.ai.
+ */
+export class StreamBuffer {
+  /**
+   * @param {(text: string) => void} onText — called each frame with chars to append
+   * @param {() => void} onDone — called when buffer fully drained after finish()
+   */
+  constructor(onText, onDone) {
+    this.queue = ''
+    this.onText = onText
+    this.onDone = onDone
+    this.rafId = null
+    this.finished = false // stream has ended, flush remaining
+    this.baseSpeed = 3 // chars per frame at 60fps (~180 chars/sec)
+  }
+
+  /** Push new text from a stream chunk into the buffer. */
+  push(text) {
+    this.queue += text
+    if (!this.rafId) {
+      this.startDrain()
+    }
+  }
+
+  /** Signal that the stream is complete — flush remaining text quickly. */
+  finish() {
+    this.finished = true
+    if (!this.rafId && this.queue.length > 0) {
+      this.startDrain()
+    }
+  }
+
+  /** Cancel the buffer (e.g. room switch). */
+  cancel() {
+    if (this.rafId) {
+      cancelAnimationFrame(this.rafId)
+      this.rafId = null
+    }
+    this.queue = ''
+    this.finished = false
+  }
+
+  /** @private */
+  startDrain() {
+    this.rafId = requestAnimationFrame(() => this.drain())
+  }
+
+  /** @private */
+  drain() {
+    if (this.queue.length === 0) {
+      this.rafId = null
+      if (this.finished) {
+        this.onDone()
+      }
+      return
+    }
+
+    // Adaptive speed:
+    // - Base: ~3 chars/frame (smooth typing feel)
+    // - If buffer > 50 chars: speed up to avoid falling behind
+    // - If stream ended: flush fast (10+ chars/frame)
+    let chars = this.baseSpeed
+
+    if (this.queue.length > 200) {
+      // Very behind — catch up aggressively
+      chars = Math.ceil(this.queue.length / 10)
+    } else if (this.queue.length > 50) {
+      // Moderately behind — speed up proportionally
+      chars = Math.ceil(this.queue.length / 15)
+    }
+
+    if (this.finished) {
+      // Stream done, flush remaining smoothly but quickly
+      chars = Math.max(chars, Math.ceil(this.queue.length / 8))
+    }
+
+    // Don't break in the middle of a multi-byte character or word
+    // (slice is safe for most LLM output which is ASCII/simple unicode)
+    const chunk = this.queue.slice(0, chars)
+    this.queue = this.queue.slice(chars)
+
+    this.onText(chunk)
+
+    this.rafId = requestAnimationFrame(() => this.drain())
+  }
+}
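To make the adaptive drain rate concrete, here is the per-frame character budget at a few queue sizes, worked from the constants in the file above (illustrative arithmetic only):

    // baseSpeed = 3 chars/frame (~180 chars/sec at 60fps)
    // queue =  30, still streaming -> 3 (base speed, smooth typing)
    // queue = 100, still streaming -> Math.ceil(100 / 15) = 7
    // queue = 300, still streaming -> Math.ceil(300 / 10) = 30
    // queue =  80, finished        -> Math.max(Math.ceil(80 / 15), Math.ceil(80 / 8)) = 10
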
@@ -52,11 +52,23 @@ async fn handle_socket(socket: WebSocket, state: Arc<AppState>, user_id: String,

     // Task: forward broadcast events to this client
     let mut send_task = tokio::spawn(async move {
-        while let Ok(event) = broadcast_rx.recv().await {
-            let rooms = rooms_clone.lock().await;
-            if rooms.contains(&event.room_id) {
-                let msg = serde_json::to_string(&event.message).unwrap();
-                if ws_tx.send(Message::Text(msg.into())).await.is_err() {
+        loop {
+            match broadcast_rx.recv().await {
+                Ok(event) => {
+                    let rooms = rooms_clone.lock().await;
+                    if rooms.contains(&event.room_id) {
+                        let msg = serde_json::to_string(&event.message).unwrap();
+                        if ws_tx.send(Message::Text(msg.into())).await.is_err() {
+                            break;
+                        }
+                    }
+                }
+                Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
+                    tracing::warn!("WS subscriber lagged, skipped {} messages", n);
+                    // Continue receiving — don't drop the connection
+                    continue;
+                }
+                Err(tokio::sync::broadcast::error::RecvError::Closed) => {
                     break;
                 }
             }
@@ -224,7 +236,10 @@ async fn handle_send_message(
     // Build tools for AI
     let tools = openrouter::build_tools();

-    // Call OpenRouter with tool loop
+    // Pre-generate AI message ID so we can reference it in stream chunks
+    let ai_msg_id = Uuid::new_v4().to_string();
+
+    // Call OpenRouter with tool loop — uses streaming for all rounds
     let mut total_prompt_tokens: u32 = 0;
     let mut total_completion_tokens: u32 = 0;
     let mut total_response_ms: u64 = 0;
@@ -233,8 +248,8 @@
     let mut had_error = false;
     let mut collected_tool_results: Vec<crate::models::ToolResult> = vec![];

-    for round in 0..MAX_TOOL_ROUNDS {
-        let result = openrouter::chat_completion(
+    'tool_loop: for round in 0..MAX_TOOL_ROUNDS {
+        let mut stream_rx = openrouter::chat_completion_stream(
             chat_history.clone(),
             &model_id,
             &state.openrouter_key,
@@ -242,82 +257,92 @@
         )
         .await;

-        match result {
-            Ok(openrouter::ChatCompletionResult::Response(text, stats)) => {
-                // Final text response — done!
-                total_prompt_tokens += stats.prompt_tokens;
-                total_completion_tokens += stats.completion_tokens;
-                total_response_ms += stats.response_ms;
-                final_model = stats.model;
-                ai_response = text;
-                break;
-            }
-            Ok(openrouter::ChatCompletionResult::ToolCalls(assistant_msg, stats)) => {
-                total_prompt_tokens += stats.prompt_tokens;
-                total_completion_tokens += stats.completion_tokens;
-                total_response_ms += stats.response_ms;
-                final_model = stats.model.clone();
-
-                tracing::info!(
-                    "AI requesting tool calls (round {}): {:?}",
-                    round + 1,
-                    assistant_msg.tool_calls.as_ref().map(|tc| tc.iter().map(|t| &t.function.name).collect::<Vec<_>>())
-                );
-
-                // Add the assistant's tool-call message to history
-                let tool_calls = assistant_msg.tool_calls.clone().unwrap_or_default();
-                chat_history.push(assistant_msg);
-
-                // Execute each tool call and add results
-                for tool_call in &tool_calls {
-                    // Extract tool input for display purposes
-                    let tool_input = extract_tool_input(&tool_call.function.name, &tool_call.function.arguments);
-
-                    // Broadcast real-time tool usage event
-                    let _ = state.tx.send(BroadcastEvent {
-                        room_id: room_id.to_string(),
-                        message: WsServerMessage::AiToolUsage {
-                            room_id: room_id.to_string(),
-                            tool_name: tool_call.function.name.clone(),
-                            status: "calling".to_string(),
-                        },
-                    });
-
-                    let tool_result = execute_tool(
-                        &tool_call.function.name,
-                        &tool_call.function.arguments,
-                        &state.brave_api_key,
-                    )
-                    .await;
-
-                    tracing::info!(
-                        "Tool {} result: {} chars",
-                        tool_call.function.name,
-                        tool_result.len()
-                    );
-
-                    // Collect tool result for inclusion in final message
-                    collected_tool_results.push(crate::models::ToolResult {
-                        tool: tool_call.function.name.clone(),
-                        input: tool_input,
-                        result: tool_result.clone(),
-                    });
-
-                    // Add tool result to history
-                    chat_history.push(openrouter::ChatMessage {
-                        role: "tool".into(),
-                        content: Some(tool_result),
-                        tool_calls: None,
-                        tool_call_id: Some(tool_call.id.clone()),
-                    });
-                }
-                // Loop continues — call OpenRouter again with tool results
-            }
-            Err(e) => {
-                tracing::error!("OpenRouter error (round {}): {}", round + 1, e);
-                ai_response = format!("*Sorry, I encountered an error: {}*", e);
-                had_error = true;
-                break;
-            }
-        }
+        while let Some(event) = stream_rx.recv().await {
+            match event {
+                openrouter::StreamEvent::Delta(text) => {
+                    // Broadcast each content chunk to clients
+                    let _ = state.tx.send(BroadcastEvent {
+                        room_id: room_id.to_string(),
+                        message: WsServerMessage::AiStreamChunk {
+                            room_id: room_id.to_string(),
+                            message_id: ai_msg_id.clone(),
+                            delta: text.clone(),
+                        },
+                    });
+                    ai_response.push_str(&text);
+                }
+                openrouter::StreamEvent::ToolCalls(assistant_msg, stats) => {
+                    total_prompt_tokens += stats.prompt_tokens;
+                    total_completion_tokens += stats.completion_tokens;
+                    total_response_ms += stats.response_ms;
+                    final_model = stats.model.clone();
+
+                    tracing::info!(
+                        "AI requesting tool calls (round {}): {:?}",
+                        round + 1,
+                        assistant_msg.tool_calls.as_ref().map(|tc| tc.iter().map(|t| &t.function.name).collect::<Vec<_>>())
+                    );
+
+                    // Add the assistant's tool-call message to history
+                    let tool_calls = assistant_msg.tool_calls.clone().unwrap_or_default();
+                    chat_history.push(assistant_msg);
+
+                    // Execute each tool call and add results
+                    for tool_call in &tool_calls {
+                        let tool_input = extract_tool_input(&tool_call.function.name, &tool_call.function.arguments);
+
+                        // Broadcast real-time tool usage event
+                        let _ = state.tx.send(BroadcastEvent {
+                            room_id: room_id.to_string(),
+                            message: WsServerMessage::AiToolUsage {
+                                room_id: room_id.to_string(),
+                                tool_name: tool_call.function.name.clone(),
+                                status: "calling".to_string(),
+                            },
+                        });
+
+                        let tool_result = execute_tool(
+                            &tool_call.function.name,
+                            &tool_call.function.arguments,
+                            &state.brave_api_key,
+                        )
+                        .await;
+
+                        tracing::info!(
+                            "Tool {} result: {} chars",
+                            tool_call.function.name,
+                            tool_result.len()
+                        );
+
+                        collected_tool_results.push(crate::models::ToolResult {
+                            tool: tool_call.function.name.clone(),
+                            input: tool_input,
+                            result: tool_result.clone(),
+                        });
+
+                        chat_history.push(openrouter::ChatMessage {
+                            role: "tool".into(),
+                            content: Some(tool_result),
+                            tool_calls: None,
+                            tool_call_id: Some(tool_call.id.clone()),
+                        });
+                    }
+                    // Continue to next round (tool loop)
+                    continue 'tool_loop;
+                }
+                openrouter::StreamEvent::Done(stats) => {
+                    total_prompt_tokens += stats.prompt_tokens;
+                    total_completion_tokens += stats.completion_tokens;
+                    total_response_ms += stats.response_ms;
+                    final_model = stats.model;
+                    break 'tool_loop;
+                }
+                openrouter::StreamEvent::Error(e) => {
+                    tracing::error!("OpenRouter stream error (round {}): {}", round + 1, e);
+                    ai_response = format!("*Sorry, I encountered an error: {}*", e);
+                    had_error = true;
+                    break 'tool_loop;
+                }
+            }
+        }
     }
@@ -327,6 +352,15 @@ async fn handle_send_message(
         ai_response = "*I used several tools but couldn't formulate a final response. Please try again.*".to_string();
     }

+    // Signal stream end so client can finalize rendering
+    let _ = state.tx.send(BroadcastEvent {
+        room_id: room_id.to_string(),
+        message: WsServerMessage::AiStreamEnd {
+            room_id: room_id.to_string(),
+            message_id: ai_msg_id.clone(),
+        },
+    });
+
     let ai_meta = if !had_error {
         Some(crate::models::AiMeta {
             model: final_model,
@@ -345,7 +379,6 @@ async fn handle_send_message(
     };

     // Store AI response
-    let ai_msg_id = Uuid::new_v4().to_string();
     let ai_now = chrono::Utc::now().to_rfc3339();

     // Serialize ai_meta for database storage
@@ -364,7 +397,7 @@ async fn handle_send_message(
     .execute(&state.db)
     .await;

-    // Broadcast AI message
+    // Broadcast final AI message (includes full content + ai_meta)
     let ai_payload = MessagePayload {
         id: ai_msg_id,
         room_id: room_id.to_string(),

@@ -72,7 +72,7 @@ async fn main() {

     tracing::info!("Database initialized");

-    let (tx, _rx) = broadcast::channel::<models::BroadcastEvent>(256);
+    let (tx, _rx) = broadcast::channel::<models::BroadcastEvent>(4096);

     let state = Arc::new(AppState {
         db,

@@ -164,6 +164,17 @@ pub enum WsServerMessage {
         tool_name: String,
         status: String,
     },
+    #[serde(rename = "ai_stream_chunk")]
+    AiStreamChunk {
+        room_id: String,
+        message_id: String,
+        delta: String,
+    },
+    #[serde(rename = "ai_stream_end")]
+    AiStreamEnd {
+        room_id: String,
+        message_id: String,
+    },
 }

 #[derive(Debug, Clone, Serialize, Deserialize)]

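For reference, the client handlers added earlier read these events as flat objects (msg.room_id, msg.message_id, msg.delta). Assuming the enum's serde tagging (not shown in this hunk) produces the event name the websocket service dispatches on, the wire payloads look roughly like:

    // ai_stream_chunk — one per content delta
    { "type": "ai_stream_chunk", "room_id": "…", "message_id": "…", "delta": "Hel" }
    // ai_stream_end — sent once after the final round
    { "type": "ai_stream_end", "room_id": "…", "message_id": "…" }
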
@@ -1,3 +1,4 @@
+use futures::StreamExt;
 use serde::{Deserialize, Serialize};

 const OPENROUTER_API_URL: &str = "https://openrouter.ai/api/v1/chat/completions";
@@ -11,6 +12,8 @@ struct ChatRequest {
     max_tokens: Option<u32>,
     #[serde(skip_serializing_if = "Option::is_none")]
     tools: Option<Vec<Tool>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    stream: Option<bool>,
 }

 #[derive(Debug, Serialize, Deserialize, Clone)]
@@ -55,29 +58,62 @@ pub struct ToolCallFunction {
 // ── Response types ──

 #[derive(Debug, Deserialize)]
-struct ChatResponse {
-    choices: Vec<Choice>,
+struct Usage {
+    prompt_tokens: Option<u32>,
+    completion_tokens: Option<u32>,
+    total_tokens: Option<u32>,
+}
+
+// ── Streaming response types ──
+
+#[derive(Debug, Deserialize)]
+struct StreamChunk {
+    choices: Option<Vec<StreamChoice>>,
     model: Option<String>,
     usage: Option<Usage>,
 }

 #[derive(Debug, Deserialize)]
-struct Choice {
-    message: ChoiceMessage,
+#[allow(dead_code)]
+struct StreamChoice {
+    delta: Option<StreamDelta>,
+    finish_reason: Option<String>,
 }

 #[derive(Debug, Deserialize)]
-struct ChoiceMessage {
+struct StreamDelta {
     content: Option<String>,
     #[serde(default)]
-    tool_calls: Option<Vec<ToolCall>>,
+    tool_calls: Option<Vec<StreamToolCall>>,
 }

 #[derive(Debug, Deserialize)]
-struct Usage {
-    prompt_tokens: Option<u32>,
-    completion_tokens: Option<u32>,
-    total_tokens: Option<u32>,
+#[allow(dead_code)]
+struct StreamToolCall {
+    index: Option<usize>,
+    id: Option<String>,
+    #[serde(default)]
+    r#type: Option<String>,
+    function: Option<StreamToolCallFunction>,
+}
+
+#[derive(Debug, Deserialize)]
+struct StreamToolCallFunction {
+    name: Option<String>,
+    arguments: Option<String>,
+}
+
+/// Events emitted during a streaming completion.
+#[derive(Debug)]
+pub enum StreamEvent {
+    /// A chunk of content text.
+    Delta(String),
+    /// AI wants to call tools (streaming accumulated the full tool_calls).
+    ToolCalls(ChatMessage, CompletionStats),
+    /// Stream finished with final stats.
+    Done(CompletionStats),
+    /// An error occurred.
+    Error(String),
 }

 /// Stats returned alongside an AI completion.
@@ -90,14 +126,6 @@ pub struct CompletionStats {
     pub response_ms: u64,
 }

-/// Result from a chat completion — either a final text response or tool calls.
-pub enum ChatCompletionResult {
-    /// AI responded with text content.
-    Response(String, CompletionStats),
-    /// AI wants to call tools. Contains the assistant message (with tool_calls) and stats.
-    ToolCalls(ChatMessage, CompletionStats),
-}
-
 /// Build the tool definitions for brave_search and web_fetch.
 pub fn build_tools() -> Vec<Tool> {
     vec![
@@ -142,85 +170,198 @@ pub fn build_tools() -> Vec<Tool> {
     ]
 }

-/// Send a chat completion request to OpenRouter.
-/// Returns either a text response or tool call requests.
-pub async fn chat_completion(
+/// Send a streaming chat completion request to OpenRouter.
+/// Sends events via the returned mpsc receiver:
+/// - `StreamEvent::Delta(text)` for each content chunk
+/// - `StreamEvent::ToolCalls(msg, stats)` if the AI wants to call tools
+/// - `StreamEvent::Done(stats)` when the stream finishes with a text response
+/// - `StreamEvent::Error(msg)` on error
+pub async fn chat_completion_stream(
     history: Vec<ChatMessage>,
     model_id: &str,
     api_key: &str,
     tools: Option<Vec<Tool>>,
-) -> Result<ChatCompletionResult, String> {
-    let client = reqwest::Client::new();
-
-    let request_body = ChatRequest {
-        model: model_id.to_string(),
-        messages: history,
-        max_tokens: Some(2048),
-        tools,
-    };
-
-    let start = std::time::Instant::now();
-
-    let response = client
-        .post(OPENROUTER_API_URL)
-        .header("Authorization", format!("Bearer {}", api_key))
-        .header("Content-Type", "application/json")
-        .header("HTTP-Referer", "http://localhost:3001")
-        .header("X-Title", "GroupChat")
-        .json(&request_body)
-        .send()
-        .await
-        .map_err(|e| format!("OpenRouter request failed: {}", e))?;
-
-    let elapsed_ms = start.elapsed().as_millis() as u64;
-
-    if !response.status().is_success() {
-        let status = response.status();
-        let body = response.text().await.unwrap_or_default();
-        return Err(format!("OpenRouter error {}: {}", status, body));
-    }
-
-    let chat_response: ChatResponse = response
-        .json()
-        .await
-        .map_err(|e| format!("Failed to parse OpenRouter response: {}", e))?;
-
-    let choice = chat_response
-        .choices
-        .first()
-        .ok_or_else(|| "No response from OpenRouter".to_string())?;
-
-    let usage = chat_response.usage.unwrap_or(Usage {
-        prompt_tokens: None,
-        completion_tokens: None,
-        total_tokens: None,
-    });
-
-    let stats = CompletionStats {
-        model: chat_response.model.unwrap_or_else(|| model_id.to_string()),
-        prompt_tokens: usage.prompt_tokens.unwrap_or(0),
-        completion_tokens: usage.completion_tokens.unwrap_or(0),
-        total_tokens: usage.total_tokens.unwrap_or(0),
-        response_ms: elapsed_ms,
-    };
-
-    // Check if the AI wants to call tools
-    if let Some(tool_calls) = &choice.message.tool_calls {
-        if !tool_calls.is_empty() {
-            // Return the assistant message with tool calls so it can be added to history
-            let assistant_msg = ChatMessage {
-                role: "assistant".into(),
-                content: choice.message.content.clone(),
-                tool_calls: Some(tool_calls.clone()),
-                tool_call_id: None,
-            };
-            return Ok(ChatCompletionResult::ToolCalls(assistant_msg, stats));
-        }
-    }
-
-    // Regular text response
-    let content = choice.message.content.clone().unwrap_or_default();
-    Ok(ChatCompletionResult::Response(content, stats))
+) -> tokio::sync::mpsc::Receiver<StreamEvent> {
+    let (tx, rx) = tokio::sync::mpsc::channel::<StreamEvent>(256);
+
+    let model_id = model_id.to_string();
+    let api_key = api_key.to_string();
+
+    tokio::spawn(async move {
+        let client = reqwest::Client::new();
+
+        let request_body = ChatRequest {
+            model: model_id.clone(),
+            messages: history,
+            max_tokens: Some(2048),
+            tools,
+            stream: Some(true),
+        };
+
+        let start = std::time::Instant::now();
+
+        let response = match client
+            .post(OPENROUTER_API_URL)
+            .header("Authorization", format!("Bearer {}", api_key))
+            .header("Content-Type", "application/json")
+            .header("HTTP-Referer", "http://localhost:3001")
+            .header("X-Title", "GroupChat")
+            .json(&request_body)
+            .send()
+            .await
+        {
+            Ok(r) => r,
+            Err(e) => {
+                let _ = tx.send(StreamEvent::Error(format!("Request failed: {}", e))).await;
+                return;
+            }
+        };
+
+        if !response.status().is_success() {
+            let status = response.status();
+            let body = response.text().await.unwrap_or_default();
+            let _ = tx.send(StreamEvent::Error(format!("OpenRouter error {}: {}", status, body))).await;
+            return;
+        }
+
+        // Read SSE stream
+        let mut byte_stream = response.bytes_stream();
+        let mut buffer = String::new();
+        let mut full_content = String::new();
+        let mut model_name = model_id.clone();
+
+        // Accumulators for streamed tool calls
+        let mut tool_call_accum: Vec<ToolCall> = Vec::new();
+        let mut has_tool_calls = false;
+
+        // Usage from the final chunk (some providers include it)
+        let mut final_usage: Option<Usage> = None;
+
+        while let Some(chunk_result) = byte_stream.next().await {
+            let bytes = match chunk_result {
+                Ok(b) => b,
+                Err(e) => {
+                    let _ = tx.send(StreamEvent::Error(format!("Stream error: {}", e))).await;
+                    return;
+                }
+            };
+
+            buffer.push_str(&String::from_utf8_lossy(&bytes));
+
+            // Process complete SSE lines
+            while let Some(line_end) = buffer.find('\n') {
+                let line = buffer[..line_end].trim().to_string();
+                buffer = buffer[line_end + 1..].to_string();
+
+                if line.is_empty() || line.starts_with(':') {
+                    continue;
+                }
+
+                if let Some(data) = line.strip_prefix("data: ") {
+                    let data = data.trim();
+
+                    if data == "[DONE]" {
+                        // Stream finished
+                        continue;
+                    }
+
+                    let chunk: StreamChunk = match serde_json::from_str(data) {
+                        Ok(c) => c,
+                        Err(_) => continue,
+                    };
+
+                    if let Some(m) = &chunk.model {
+                        model_name = m.clone();
+                    }
+
+                    if let Some(u) = chunk.usage {
+                        final_usage = Some(u);
+                    }
+
+                    if let Some(choices) = &chunk.choices {
+                        if let Some(choice) = choices.first() {
+                            if let Some(delta) = &choice.delta {
+                                // Handle content delta
+                                if let Some(content) = &delta.content {
+                                    if !content.is_empty() {
+                                        full_content.push_str(content);
+                                        let _ = tx.send(StreamEvent::Delta(content.clone())).await;
+                                    }
+                                }
+
+                                // Handle tool call deltas (accumulate incrementally)
+                                if let Some(tcs) = &delta.tool_calls {
+                                    has_tool_calls = true;
+                                    for tc in tcs {
+                                        let idx = tc.index.unwrap_or(0);
+
+                                        // Ensure we have enough slots
+                                        while tool_call_accum.len() <= idx {
+                                            tool_call_accum.push(ToolCall {
+                                                id: String::new(),
+                                                r#type: "function".to_string(),
+                                                function: ToolCallFunction {
+                                                    name: String::new(),
+                                                    arguments: String::new(),
+                                                },
+                                            });
+                                        }
+
+                                        if let Some(id) = &tc.id {
+                                            tool_call_accum[idx].id = id.clone();
+                                        }
+                                        if let Some(func) = &tc.function {
+                                            if let Some(name) = &func.name {
+                                                tool_call_accum[idx].function.name.push_str(name);
+                                            }
+                                            if let Some(args) = &func.arguments {
+                                                tool_call_accum[idx].function.arguments.push_str(args);
+                                            }
+                                        }
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+
+        let elapsed_ms = start.elapsed().as_millis() as u64;
+
+        let (prompt_tokens, completion_tokens, total_tokens) = match &final_usage {
+            Some(u) => (
+                u.prompt_tokens.unwrap_or(0),
+                u.completion_tokens.unwrap_or(0),
+                u.total_tokens.unwrap_or(0),
+            ),
+            None => (0, 0, 0),
+        };
+
+        let stats = CompletionStats {
+            model: model_name,
+            prompt_tokens,
+            completion_tokens,
+            total_tokens,
+            response_ms: elapsed_ms,
+        };
+
+        if has_tool_calls && !tool_call_accum.is_empty() {
+            // AI requested tool calls
+            let assistant_msg = ChatMessage {
+                role: "assistant".into(),
+                content: if full_content.is_empty() { None } else { Some(full_content) },
+                tool_calls: Some(tool_call_accum),
+                tool_call_id: None,
+            };
+            let _ = tx.send(StreamEvent::ToolCalls(assistant_msg, stats)).await;
+        } else {
+            // Normal text response completed
+            let _ = tx.send(StreamEvent::Done(stats)).await;
+        }
+    });
+
+    rx
 }

 /// Build the message history for OpenRouter from stored messages.