use futures::StreamExt; use serde::{Deserialize, Serialize}; const OPENROUTER_API_URL: &str = "https://openrouter.ai/api/v1/chat/completions"; // ── Request types ── #[derive(Debug, Serialize)] struct ChatRequest { model: String, messages: Vec, max_tokens: Option, #[serde(skip_serializing_if = "Option::is_none")] tools: Option>, #[serde(skip_serializing_if = "Option::is_none")] stream: Option, } /// Content can be a plain text string or multimodal (text + image) parts. /// Serializes to a JSON string or a JSON array, matching the OpenAI/OpenRouter API. #[derive(Debug, Serialize, Deserialize, Clone)] #[serde(untagged)] pub enum Content { Text(String), Parts(Vec), } #[derive(Debug, Serialize, Deserialize, Clone)] #[serde(tag = "type")] pub enum ContentPart { #[serde(rename = "text")] Text { text: String }, #[serde(rename = "image_url")] ImageUrl { image_url: ImageUrlData }, } #[derive(Debug, Serialize, Deserialize, Clone)] pub struct ImageUrlData { pub url: String, } #[derive(Debug, Serialize, Deserialize, Clone)] pub struct ChatMessage { pub role: String, #[serde(skip_serializing_if = "Option::is_none")] pub content: Option, #[serde(skip_serializing_if = "Option::is_none")] pub tool_calls: Option>, #[serde(skip_serializing_if = "Option::is_none")] pub tool_call_id: Option, } // ── Tool definition types ── #[derive(Debug, Serialize, Deserialize, Clone)] pub struct Tool { pub r#type: String, pub function: ToolFunction, } #[derive(Debug, Serialize, Deserialize, Clone)] pub struct ToolFunction { pub name: String, pub description: String, pub parameters: serde_json::Value, } #[derive(Debug, Serialize, Deserialize, Clone)] pub struct ToolCall { pub id: String, pub r#type: String, pub function: ToolCallFunction, } #[derive(Debug, Serialize, Deserialize, Clone)] pub struct ToolCallFunction { pub name: String, pub arguments: String, } // ── Response types ── #[derive(Debug, Deserialize)] struct Usage { prompt_tokens: Option, completion_tokens: Option, total_tokens: Option, } // ── Streaming response types ── #[derive(Debug, Deserialize)] struct StreamChunk { choices: Option>, model: Option, usage: Option, } #[derive(Debug, Deserialize)] #[allow(dead_code)] struct StreamChoice { delta: Option, finish_reason: Option, } #[derive(Debug, Deserialize)] struct StreamDelta { content: Option, #[serde(default)] tool_calls: Option>, } #[derive(Debug, Deserialize)] #[allow(dead_code)] struct StreamToolCall { index: Option, id: Option, #[serde(default)] r#type: Option, function: Option, } #[derive(Debug, Deserialize)] struct StreamToolCallFunction { name: Option, arguments: Option, } /// Events emitted during a streaming completion. #[derive(Debug)] pub enum StreamEvent { /// A chunk of content text. Delta(String), /// AI wants to call tools (streaming accumulated the full tool_calls). ToolCalls(ChatMessage, CompletionStats), /// Stream finished with final stats. Done(CompletionStats), /// An error occurred. Error(String), } /// Stats returned alongside an AI completion. #[derive(Debug, Clone, Serialize)] pub struct CompletionStats { pub model: String, pub prompt_tokens: u32, pub completion_tokens: u32, pub total_tokens: u32, pub response_ms: u64, } /// Build the tool definitions for web_search and web_fetch. pub fn build_tools() -> Vec { vec![ Tool { r#type: "function".into(), function: ToolFunction { name: "web_search".into(), description: "Search the web for current information. Use this when users ask about recent events, need factual data you're unsure about, or want up-to-date information.".into(), parameters: serde_json::json!({ "type": "object", "properties": { "query": { "type": "string", "description": "The search query" }, "count": { "type": "integer", "description": "Number of results (1-10, default 5)" } }, "required": ["query"] }), }, }, Tool { r#type: "function".into(), function: ToolFunction { name: "web_fetch".into(), description: "Fetch and read the content of a web page. Use this to read articles, documentation, or any URL shared by users or found in search results.".into(), parameters: serde_json::json!({ "type": "object", "properties": { "url": { "type": "string", "description": "The URL to fetch" } }, "required": ["url"] }), }, }, ] } /// Send a streaming chat completion request to OpenRouter. /// Sends events via the returned mpsc receiver: /// - `StreamEvent::Delta(text)` for each content chunk /// - `StreamEvent::ToolCalls(msg, stats)` if the AI wants to call tools /// - `StreamEvent::Done(stats)` when the stream finishes with a text response /// - `StreamEvent::Error(msg)` on error pub async fn chat_completion_stream( history: Vec, model_id: &str, api_key: &str, tools: Option>, ) -> tokio::sync::mpsc::Receiver { let (tx, rx) = tokio::sync::mpsc::channel::(256); let model_id = model_id.to_string(); let api_key = api_key.to_string(); tokio::spawn(async move { let client = reqwest::Client::new(); let request_body = ChatRequest { model: model_id.clone(), messages: history, max_tokens: Some(2048), tools, stream: Some(true), }; let start = std::time::Instant::now(); let response = match client .post(OPENROUTER_API_URL) .header("Authorization", format!("Bearer {}", api_key)) .header("Content-Type", "application/json") .header("HTTP-Referer", "http://localhost:3001") .header("X-Title", "GroupChat") .json(&request_body) .send() .await { Ok(r) => r, Err(e) => { let _ = tx .send(StreamEvent::Error(format!("Request failed: {}", e))) .await; return; } }; if !response.status().is_success() { let status = response.status(); let body = response.text().await.unwrap_or_default(); let _ = tx .send(StreamEvent::Error(format!( "OpenRouter error {}: {}", status, body ))) .await; return; } // Read SSE stream let mut byte_stream = response.bytes_stream(); let mut buffer = String::new(); let mut full_content = String::new(); let mut model_name = model_id.clone(); // Accumulators for streamed tool calls let mut tool_call_accum: Vec = Vec::new(); let mut has_tool_calls = false; // Usage from the final chunk (some providers include it) let mut final_usage: Option = None; while let Some(chunk_result) = byte_stream.next().await { let bytes = match chunk_result { Ok(b) => b, Err(e) => { let _ = tx .send(StreamEvent::Error(format!("Stream error: {}", e))) .await; return; } }; buffer.push_str(&String::from_utf8_lossy(&bytes)); // Process complete SSE lines while let Some(line_end) = buffer.find('\n') { let line = buffer[..line_end].trim().to_string(); buffer = buffer[line_end + 1..].to_string(); if line.is_empty() || line.starts_with(':') { continue; } if let Some(data) = line.strip_prefix("data: ") { let data = data.trim(); if data == "[DONE]" { // Stream finished continue; } let chunk: StreamChunk = match serde_json::from_str(data) { Ok(c) => c, Err(_) => continue, }; if let Some(m) = &chunk.model { model_name = m.clone(); } if let Some(u) = chunk.usage { final_usage = Some(u); } if let Some(choices) = &chunk.choices { if let Some(choice) = choices.first() { if let Some(delta) = &choice.delta { // Handle content delta if let Some(content) = &delta.content { if !content.is_empty() { full_content.push_str(content); let _ = tx.send(StreamEvent::Delta(content.clone())).await; } } // Handle tool call deltas (accumulate incrementally) if let Some(tcs) = &delta.tool_calls { has_tool_calls = true; for tc in tcs { let idx = tc.index.unwrap_or(0); // Ensure we have enough slots while tool_call_accum.len() <= idx { tool_call_accum.push(ToolCall { id: String::new(), r#type: "function".to_string(), function: ToolCallFunction { name: String::new(), arguments: String::new(), }, }); } if let Some(id) = &tc.id { tool_call_accum[idx].id = id.clone(); } if let Some(func) = &tc.function { if let Some(name) = &func.name { tool_call_accum[idx].function.name.push_str(name); } if let Some(args) = &func.arguments { tool_call_accum[idx] .function .arguments .push_str(args); } } } } } } } } } } let elapsed_ms = start.elapsed().as_millis() as u64; let (prompt_tokens, completion_tokens, total_tokens) = match &final_usage { Some(u) => ( u.prompt_tokens.unwrap_or(0), u.completion_tokens.unwrap_or(0), u.total_tokens.unwrap_or(0), ), None => (0, 0, 0), }; let stats = CompletionStats { model: model_name, prompt_tokens, completion_tokens, total_tokens, response_ms: elapsed_ms, }; if has_tool_calls && !tool_call_accum.is_empty() { // AI requested tool calls let assistant_msg = ChatMessage { role: "assistant".into(), content: if full_content.is_empty() { None } else { Some(Content::Text(full_content)) }, tool_calls: Some(tool_call_accum), tool_call_id: None, }; let _ = tx.send(StreamEvent::ToolCalls(assistant_msg, stats)).await; } else { // Normal text response completed let _ = tx.send(StreamEvent::Done(stats)).await; } }); rx } /// Build the message history for OpenRouter from stored messages. /// Includes the system prompt as the first message. /// Messages with image data URLs will be sent as multimodal content. pub fn build_chat_history( system_prompt: &str, messages: &[(String, String, bool, Option)], // (sender_name, content, is_ai, image_data_url) ) -> Vec { let mut history = vec![ChatMessage { role: "system".to_string(), content: Some(Content::Text(system_prompt.to_string())), tool_calls: None, tool_call_id: None, }]; for (sender_name, content, is_ai, image_data_url) in messages { if *is_ai { history.push(ChatMessage { role: "assistant".to_string(), content: Some(Content::Text(content.clone())), tool_calls: None, tool_call_id: None, }); } else { let text = if content.is_empty() { format!("[{}] shared an image:", sender_name) } else { format!("[{}]: {}", sender_name, content) }; let msg_content = if let Some(data_url) = image_data_url { Content::Parts(vec![ ContentPart::Text { text }, ContentPart::ImageUrl { image_url: ImageUrlData { url: data_url.clone(), }, }, ]) } else { Content::Text(text) }; history.push(ChatMessage { role: "user".to_string(), content: Some(msg_content), tool_calls: None, tool_call_id: None, }); } } history }