//! groupchat/server/src/services/openrouter.rs — OpenRouter chat-completion
//! client (streaming SSE, tool calling, multimodal message building).
use futures::StreamExt;
use serde::{Deserialize, Serialize};
const OPENROUTER_API_URL: &str = "https://openrouter.ai/api/v1/chat/completions";
// ── Request types ──
/// Request body for the OpenRouter chat-completions endpoint.
#[derive(Debug, Serialize)]
struct ChatRequest {
    /// Model identifier, e.g. "openai/gpt-4o".
    model: String,
    /// Conversation history, oldest first (system prompt included).
    messages: Vec<ChatMessage>,
    /// Upper bound on generated tokens. Omitted from the JSON when `None`,
    /// consistent with the other optional fields (previously a `None` would
    /// have been serialized as `"max_tokens": null`).
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    /// Tool definitions offered to the model; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<Tool>>,
    /// Request SSE streaming when `Some(true)`; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
}
/// Content can be a plain text string or multimodal (text + image) parts.
/// Serializes to a JSON string or a JSON array, matching the OpenAI/OpenRouter API.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(untagged)]
pub enum Content {
    /// Plain text — serialized as a bare JSON string.
    Text(String),
    /// Multimodal content — serialized as a JSON array of typed part objects.
    Parts(Vec<ContentPart>),
}
/// One element of a multimodal message, discriminated by its JSON `"type"` field.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(tag = "type")]
pub enum ContentPart {
    /// `{"type": "text", "text": "..."}`
    #[serde(rename = "text")]
    Text { text: String },
    /// `{"type": "image_url", "image_url": {"url": "..."}}`
    #[serde(rename = "image_url")]
    ImageUrl { image_url: ImageUrlData },
}
/// Wrapper object for an image reference inside a multimodal part.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ImageUrlData {
    /// Image location; `build_chat_history` passes data URLs here, but any
    /// URL form the provider accepts is representable.
    pub url: String,
}
/// A single message in the chat history, in OpenAI/OpenRouter wire format.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ChatMessage {
    /// Message role — this file produces "system", "user", and "assistant".
    pub role: String,
    /// Message body; `None` for assistant messages that carry only tool calls.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<Content>,
    /// Tool invocations requested by the assistant, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    /// Id of the tool call this message answers (for tool-result messages).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}
// ── Tool definition types ──
/// A tool the model may invoke, in OpenAI wire format.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Tool {
    /// Always "function" in this codebase (see `build_tools`).
    pub r#type: String,
    /// The callable's name, description, and parameter schema.
    pub function: ToolFunction,
}
/// Declaration of a callable function exposed to the model.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ToolFunction {
    /// Function name, e.g. "web_search".
    pub name: String,
    /// Human/model-readable description of when to use the tool.
    pub description: String,
    /// JSON Schema describing the accepted arguments.
    pub parameters: serde_json::Value,
}
/// A concrete tool invocation requested by the assistant.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ToolCall {
    /// Provider-assigned id, echoed back via `ChatMessage::tool_call_id`.
    pub id: String,
    /// Call kind; "function" for the tools defined here.
    pub r#type: String,
    /// Which function to call and with what arguments.
    pub function: ToolCallFunction,
}
/// The function name/arguments pair inside a `ToolCall`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ToolCallFunction {
    pub name: String,
    /// Arguments as a raw JSON string (the wire format sends JSON-in-a-string).
    pub arguments: String,
}
// ── Response types ──
/// Token-usage block some providers attach to the final stream chunk.
/// All fields optional because providers may omit any of them.
#[derive(Debug, Deserialize)]
struct Usage {
    prompt_tokens: Option<u32>,
    completion_tokens: Option<u32>,
    total_tokens: Option<u32>,
}
// ── Streaming response types ──
/// One parsed SSE `data:` payload from the completion stream.
/// Every field is optional so partial or heterogeneous chunks still deserialize.
#[derive(Debug, Deserialize)]
struct StreamChunk {
    choices: Option<Vec<StreamChoice>>,
    /// Model name echoed by the server (may differ from the requested id).
    model: Option<String>,
    /// Usually only present on the final chunk, if at all.
    usage: Option<Usage>,
}
/// A choice entry within a stream chunk; only the first choice is consumed.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct StreamChoice {
    delta: Option<StreamDelta>,
    /// Deserialized but currently unread (hence the `allow(dead_code)`).
    finish_reason: Option<String>,
}
/// Incremental update inside one choice: a text fragment and/or tool-call fragments.
#[derive(Debug, Deserialize)]
struct StreamDelta {
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<StreamToolCall>>,
}
/// Fragment of a tool call; the full call is assembled across many deltas.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct StreamToolCall {
    /// Which parallel tool call this fragment belongs to (defaults to 0 downstream).
    index: Option<usize>,
    /// Present on the first fragment of a call only.
    id: Option<String>,
    #[serde(default)]
    r#type: Option<String>,
    function: Option<StreamToolCallFunction>,
}
/// Name/argument fragments for a streamed tool call; both are concatenated
/// across fragments by the accumulator in `chat_completion_stream`.
#[derive(Debug, Deserialize)]
struct StreamToolCallFunction {
    name: Option<String>,
    arguments: Option<String>,
}
/// Events emitted during a streaming completion, delivered over the mpsc
/// channel returned by `chat_completion_stream`.
#[derive(Debug)]
pub enum StreamEvent {
    /// A chunk of content text.
    Delta(String),
    /// AI wants to call tools (streaming accumulated the full tool_calls).
    ToolCalls(ChatMessage, CompletionStats),
    /// Stream finished with final stats.
    Done(CompletionStats),
    /// An error occurred; the producing task stops after sending this.
    Error(String),
}
/// Stats returned alongside an AI completion.
#[derive(Debug, Clone, Serialize)]
pub struct CompletionStats {
    /// Model that actually answered (server-echoed, else the requested id).
    pub model: String,
    /// Token counts; 0 when the provider did not report usage.
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
    /// Wall-clock time from request send to end of stream.
    pub response_ms: u64,
}
/// Build the tool definitions for web_search and web_fetch, in the
/// OpenAI-style `{"type": "function", "function": {...}}` shape.
pub fn build_tools() -> Vec<Tool> {
    // Small constructor so each tool is declared once, data-first.
    let make = |name: &str, description: &str, parameters: serde_json::Value| Tool {
        r#type: "function".into(),
        function: ToolFunction {
            name: name.into(),
            description: description.into(),
            parameters,
        },
    };

    let web_search = make(
        "web_search",
        "Search the web for current information. Use this when users ask about recent events, need factual data you're unsure about, or want up-to-date information.",
        serde_json::json!({
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "The search query"
                },
                "count": {
                    "type": "integer",
                    "description": "Number of results (1-10, default 5)"
                }
            },
            "required": ["query"]
        }),
    );

    let web_fetch = make(
        "web_fetch",
        "Fetch and read the content of a web page. Use this to read articles, documentation, or any URL shared by users or found in search results.",
        serde_json::json!({
            "type": "object",
            "properties": {
                "url": {
                    "type": "string",
                    "description": "The URL to fetch"
                }
            },
            "required": ["url"]
        }),
    );

    vec![web_search, web_fetch]
}
/// Send a streaming chat completion request to OpenRouter.
///
/// Spawns a background task that performs the HTTP request and parses the
/// SSE response; events arrive via the returned mpsc receiver:
/// - `StreamEvent::Delta(text)` for each content chunk
/// - `StreamEvent::ToolCalls(msg, stats)` if the AI wants to call tools
/// - `StreamEvent::Done(stats)` when the stream finishes with a text response
/// - `StreamEvent::Error(msg)` on error (the task returns afterwards, so no
///   `Done` follows an `Error`)
pub async fn chat_completion_stream(
    history: Vec<ChatMessage>,
    model_id: &str,
    api_key: &str,
    tools: Option<Vec<Tool>>,
) -> tokio::sync::mpsc::Receiver<StreamEvent> {
    // Bounded channel: a slow consumer back-pressures the SSE reader task.
    let (tx, rx) = tokio::sync::mpsc::channel::<StreamEvent>(256);
    // Clone borrowed args into owned values so the spawned task is 'static.
    let model_id = model_id.to_string();
    let api_key = api_key.to_string();
    tokio::spawn(async move {
        let client = reqwest::Client::new();
        let request_body = ChatRequest {
            model: model_id.clone(),
            messages: history,
            max_tokens: Some(2048),
            tools,
            stream: Some(true),
        };
        // Wall-clock timer feeding `response_ms` in the final stats.
        let start = std::time::Instant::now();
        let response = match client
            .post(OPENROUTER_API_URL)
            .header("Authorization", format!("Bearer {}", api_key))
            .header("Content-Type", "application/json")
            // App-attribution headers used by OpenRouter.
            .header("HTTP-Referer", "http://localhost:3001")
            .header("X-Title", "GroupChat")
            .json(&request_body)
            .send()
            .await
        {
            Ok(r) => r,
            Err(e) => {
                // Transport-level failure before any stream was opened.
                let _ = tx
                    .send(StreamEvent::Error(format!("Request failed: {}", e)))
                    .await;
                return;
            }
        };
        if !response.status().is_success() {
            // Non-2xx: surface status plus response body as a single error event.
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            let _ = tx
                .send(StreamEvent::Error(format!(
                    "OpenRouter error {}: {}",
                    status, body
                )))
                .await;
            return;
        }
        // Read SSE stream
        let mut byte_stream = response.bytes_stream();
        // Unprocessed text carried across network chunks (an SSE line can be
        // split over several chunks).
        let mut buffer = String::new();
        // Full assistant text assembled from all content deltas.
        let mut full_content = String::new();
        // Model name echoed by the server; falls back to the requested id.
        let mut model_name = model_id.clone();
        // Accumulators for streamed tool calls
        let mut tool_call_accum: Vec<ToolCall> = Vec::new();
        let mut has_tool_calls = false;
        // Usage from the final chunk (some providers include it)
        let mut final_usage: Option<Usage> = None;
        while let Some(chunk_result) = byte_stream.next().await {
            let bytes = match chunk_result {
                Ok(b) => b,
                Err(e) => {
                    let _ = tx
                        .send(StreamEvent::Error(format!("Stream error: {}", e)))
                        .await;
                    return;
                }
            };
            // NOTE(review): from_utf8_lossy per chunk could mangle a multi-byte
            // UTF-8 character split across a chunk boundary — confirm whether
            // non-ASCII deltas can be affected in practice.
            buffer.push_str(&String::from_utf8_lossy(&bytes));
            // Process complete SSE lines
            while let Some(line_end) = buffer.find('\n') {
                let line = buffer[..line_end].trim().to_string();
                buffer = buffer[line_end + 1..].to_string();
                // Skip blank separator lines and SSE comments (": keep-alive").
                if line.is_empty() || line.starts_with(':') {
                    continue;
                }
                if let Some(data) = line.strip_prefix("data: ") {
                    let data = data.trim();
                    if data == "[DONE]" {
                        // Stream finished
                        continue;
                    }
                    // Unparseable payloads are skipped silently.
                    let chunk: StreamChunk = match serde_json::from_str(data) {
                        Ok(c) => c,
                        Err(_) => continue,
                    };
                    if let Some(m) = &chunk.model {
                        model_name = m.clone();
                    }
                    // Later usage blocks overwrite earlier ones; the final
                    // chunk's values win.
                    if let Some(u) = chunk.usage {
                        final_usage = Some(u);
                    }
                    if let Some(choices) = &chunk.choices {
                        if let Some(choice) = choices.first() {
                            if let Some(delta) = &choice.delta {
                                // Handle content delta
                                if let Some(content) = &delta.content {
                                    if !content.is_empty() {
                                        full_content.push_str(content);
                                        let _ = tx.send(StreamEvent::Delta(content.clone())).await;
                                    }
                                }
                                // Handle tool call deltas (accumulate incrementally)
                                if let Some(tcs) = &delta.tool_calls {
                                    has_tool_calls = true;
                                    for tc in tcs {
                                        // `index` selects which parallel tool
                                        // call this fragment extends.
                                        let idx = tc.index.unwrap_or(0);
                                        // Ensure we have enough slots
                                        while tool_call_accum.len() <= idx {
                                            tool_call_accum.push(ToolCall {
                                                id: String::new(),
                                                r#type: "function".to_string(),
                                                function: ToolCallFunction {
                                                    name: String::new(),
                                                    arguments: String::new(),
                                                },
                                            });
                                        }
                                        // The id arrives whole (assignment);
                                        // name/arguments arrive in fragments
                                        // and are concatenated.
                                        if let Some(id) = &tc.id {
                                            tool_call_accum[idx].id = id.clone();
                                        }
                                        if let Some(func) = &tc.function {
                                            if let Some(name) = &func.name {
                                                tool_call_accum[idx].function.name.push_str(name);
                                            }
                                            if let Some(args) = &func.arguments {
                                                tool_call_accum[idx]
                                                    .function
                                                    .arguments
                                                    .push_str(args);
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
        let elapsed_ms = start.elapsed().as_millis() as u64;
        // Token counts default to 0 when the provider omitted usage data.
        let (prompt_tokens, completion_tokens, total_tokens) = match &final_usage {
            Some(u) => (
                u.prompt_tokens.unwrap_or(0),
                u.completion_tokens.unwrap_or(0),
                u.total_tokens.unwrap_or(0),
            ),
            None => (0, 0, 0),
        };
        let stats = CompletionStats {
            model: model_name,
            prompt_tokens,
            completion_tokens,
            total_tokens,
            response_ms: elapsed_ms,
        };
        if has_tool_calls && !tool_call_accum.is_empty() {
            // AI requested tool calls: emit the assembled assistant message so
            // the caller can execute the tools and continue the conversation.
            let assistant_msg = ChatMessage {
                role: "assistant".into(),
                // Content is optional here — tool-call-only replies carry none.
                content: if full_content.is_empty() {
                    None
                } else {
                    Some(Content::Text(full_content))
                },
                tool_calls: Some(tool_call_accum),
                tool_call_id: None,
            };
            let _ = tx.send(StreamEvent::ToolCalls(assistant_msg, stats)).await;
        } else {
            // Normal text response completed
            let _ = tx.send(StreamEvent::Done(stats)).await;
        }
    });
    rx
}
/// Build the message history for OpenRouter from stored messages.
/// The system prompt becomes the first message; messages carrying an image
/// data URL are emitted as multimodal (text + image) content.
pub fn build_chat_history(
    system_prompt: &str,
    messages: &[(String, String, bool, Option<String>)], // (sender_name, content, is_ai, image_data_url)
) -> Vec<ChatMessage> {
    // One slot per stored message plus the leading system prompt.
    let mut history = Vec::with_capacity(messages.len() + 1);
    history.push(ChatMessage {
        role: "system".to_string(),
        content: Some(Content::Text(system_prompt.to_string())),
        tool_calls: None,
        tool_call_id: None,
    });
    for (sender_name, content, is_ai, image_data_url) in messages {
        // Decide role and body in one place, then push a single message.
        let (role, body) = if *is_ai {
            // AI turns are replayed verbatim as assistant messages.
            ("assistant", Content::Text(content.clone()))
        } else {
            // Human turns are prefixed with the sender's name so the model
            // can tell group-chat participants apart.
            let text = match content.is_empty() {
                true => format!("[{}] shared an image:", sender_name),
                false => format!("[{}]: {}", sender_name, content),
            };
            let body = match image_data_url {
                Some(data_url) => Content::Parts(vec![
                    ContentPart::Text { text },
                    ContentPart::ImageUrl {
                        image_url: ImageUrlData {
                            url: data_url.clone(),
                        },
                    },
                ]),
                None => Content::Text(text),
            };
            ("user", body)
        };
        history.push(ChatMessage {
            role: role.to_string(),
            content: Some(body),
            tool_calls: None,
            tool_call_id: None,
        });
    }
    history
}