//! OpenRouter chat-completion client: request/response wire types,
//! tool definitions, and an SSE streaming completion helper.
use futures::StreamExt;
|
|
use serde::{Deserialize, Serialize};
|
|
|
|
/// OpenRouter chat-completions endpoint (OpenAI-compatible wire format).
const OPENROUTER_API_URL: &str = "https://openrouter.ai/api/v1/chat/completions";
|
|
|
|
// ── Request types ──
|
|
|
|
#[derive(Debug, Serialize)]
|
|
struct ChatRequest {
|
|
model: String,
|
|
messages: Vec<ChatMessage>,
|
|
max_tokens: Option<u32>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
tools: Option<Vec<Tool>>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
stream: Option<bool>,
|
|
}
|
|
|
|
/// Content can be a plain text string or multimodal (text + image) parts.
/// Serializes to a JSON string or a JSON array, matching the OpenAI/OpenRouter API.
/// `untagged` means deserialization tries each variant in order: a JSON string
/// becomes `Text`, a JSON array becomes `Parts`.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(untagged)]
pub enum Content {
    /// Plain text content — serializes as a bare JSON string.
    Text(String),
    /// Multimodal content — serializes as a JSON array of typed parts.
    Parts(Vec<ContentPart>),
}
|
|
|
|
/// One element of a multimodal content array.
/// The serde `tag = "type"` attribute produces/expects the OpenAI shape:
/// `{"type": "text", "text": ...}` or `{"type": "image_url", "image_url": {...}}`.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(tag = "type")]
pub enum ContentPart {
    /// A text fragment.
    #[serde(rename = "text")]
    Text { text: String },
    /// An image reference (URL or data URL).
    #[serde(rename = "image_url")]
    ImageUrl { image_url: ImageUrlData },
}
|
|
|
|
/// Wrapper for an image reference inside an `image_url` content part.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ImageUrlData {
    /// Image location — in this codebase a `data:` URL built by the caller.
    pub url: String,
}
|
|
|
|
/// A single chat message in the OpenAI/OpenRouter wire format.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ChatMessage {
    /// Message role; this file constructs "system", "user", and "assistant".
    pub role: String,
    /// Message body; `None` for assistant tool-call messages that carry no text.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<Content>,
    /// Tool calls requested by the assistant, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    /// Id of the tool call this message answers — presumably set by callers
    /// when building tool-result messages; not populated in this file.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}
|
|
|
|
// ── Tool definition types ──
|
|
|
|
/// A tool the model is allowed to call.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Tool {
    /// Tool kind; always "function" in this codebase (see `build_tools`).
    pub r#type: String,
    /// The callable function's schema.
    pub function: ToolFunction,
}
|
|
|
|
/// Schema describing a callable function exposed to the model.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ToolFunction {
    /// Function name the model uses to invoke it.
    pub name: String,
    /// Natural-language description guiding when the model should call it.
    pub description: String,
    /// JSON Schema object describing the accepted arguments.
    pub parameters: serde_json::Value,
}
|
|
|
|
/// A concrete tool invocation requested by the assistant.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ToolCall {
    /// Provider-assigned call id; echoed back in the tool-result message.
    pub id: String,
    /// Call kind; set to "function" when accumulated from stream deltas.
    pub r#type: String,
    /// The function name and its (JSON-encoded) arguments.
    pub function: ToolCallFunction,
}
|
|
|
|
/// Function name plus arguments for a requested tool call.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ToolCallFunction {
    /// Name of the function being invoked.
    pub name: String,
    /// Arguments as a raw JSON string (kept unparsed, per the OpenAI format).
    pub arguments: String,
}
|
|
|
|
// ── Response types ──
|
|
|
|
/// Token-usage block as reported by the provider.
/// All fields optional: providers differ in what they populate.
#[derive(Debug, Deserialize)]
struct Usage {
    /// Tokens consumed by the prompt.
    prompt_tokens: Option<u32>,
    /// Tokens generated in the completion.
    completion_tokens: Option<u32>,
    /// Sum of prompt and completion tokens.
    total_tokens: Option<u32>,
}
|
|
|
|
// ── Streaming response types ──
|
|
|
|
/// One parsed SSE `data:` payload from the streaming response.
/// Everything is optional because providers vary in which chunks carry what.
#[derive(Debug, Deserialize)]
struct StreamChunk {
    /// Incremental choices; only the first element is consumed downstream.
    choices: Option<Vec<StreamChoice>>,
    /// Model that actually served the request (may differ from the one requested).
    model: Option<String>,
    /// Token usage; some providers include it on the final chunk.
    usage: Option<Usage>,
}
|
|
|
|
/// A single choice within a stream chunk.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct StreamChoice {
    /// Incremental content / tool-call fragment for this choice.
    delta: Option<StreamDelta>,
    /// Why generation stopped (e.g. set on the last content chunk); currently unread.
    finish_reason: Option<String>,
}
|
|
|
|
/// Incremental payload of a streamed choice.
#[derive(Debug, Deserialize)]
struct StreamDelta {
    /// Next fragment of the assistant's text, if any.
    content: Option<String>,
    /// Tool-call fragments to be accumulated by index across chunks.
    #[serde(default)]
    tool_calls: Option<Vec<StreamToolCall>>,
}
|
|
|
|
/// A fragment of one tool call, spread across several stream chunks.
/// `id` and the function `name` arrive on early fragments; `arguments`
/// is delivered piecewise and concatenated by the consumer.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct StreamToolCall {
    /// Position of this call in the tool_calls array (defaults to 0 downstream).
    index: Option<usize>,
    /// Provider-assigned call id, when present in this fragment.
    id: Option<String>,
    /// Call kind ("function"); currently unread.
    #[serde(default)]
    r#type: Option<String>,
    /// Name/arguments fragment, when present.
    function: Option<StreamToolCallFunction>,
}
|
|
|
|
/// Name/arguments fragment of a streamed tool call.
#[derive(Debug, Deserialize)]
struct StreamToolCallFunction {
    /// Function-name fragment (appended to the accumulator when present).
    name: Option<String>,
    /// JSON-arguments fragment (appended to the accumulator when present).
    arguments: Option<String>,
}
|
|
|
|
/// Events emitted during a streaming completion.
/// Exactly one of `ToolCalls`, `Done`, or `Error` terminates the stream;
/// zero or more `Delta` events precede it.
#[derive(Debug)]
pub enum StreamEvent {
    /// A chunk of content text.
    Delta(String),
    /// AI wants to call tools (streaming accumulated the full tool_calls).
    /// Carries the ready-to-append assistant message and run stats.
    ToolCalls(ChatMessage, CompletionStats),
    /// Stream finished with final stats.
    Done(CompletionStats),
    /// An error occurred.
    Error(String),
}
|
|
|
|
/// Stats returned alongside an AI completion.
#[derive(Debug, Clone, Serialize)]
pub struct CompletionStats {
    /// Model that actually served the request (falls back to the requested id).
    pub model: String,
    /// Prompt tokens; 0 when the provider reported no usage.
    pub prompt_tokens: u32,
    /// Completion tokens; 0 when the provider reported no usage.
    pub completion_tokens: u32,
    /// Total tokens; 0 when the provider reported no usage.
    pub total_tokens: u32,
    /// Wall-clock time from request send to end of stream, in milliseconds.
    pub response_ms: u64,
}
|
|
|
|
/// Build the tool definitions for web_search and web_fetch.
|
|
pub fn build_tools() -> Vec<Tool> {
|
|
vec![
|
|
Tool {
|
|
r#type: "function".into(),
|
|
function: ToolFunction {
|
|
name: "web_search".into(),
|
|
description: "Search the web for current information. Use this when users ask about recent events, need factual data you're unsure about, or want up-to-date information.".into(),
|
|
parameters: serde_json::json!({
|
|
"type": "object",
|
|
"properties": {
|
|
"query": {
|
|
"type": "string",
|
|
"description": "The search query"
|
|
},
|
|
"count": {
|
|
"type": "integer",
|
|
"description": "Number of results (1-10, default 5)"
|
|
}
|
|
},
|
|
"required": ["query"]
|
|
}),
|
|
},
|
|
},
|
|
Tool {
|
|
r#type: "function".into(),
|
|
function: ToolFunction {
|
|
name: "web_fetch".into(),
|
|
description: "Fetch and read the content of a web page. Use this to read articles, documentation, or any URL shared by users or found in search results.".into(),
|
|
parameters: serde_json::json!({
|
|
"type": "object",
|
|
"properties": {
|
|
"url": {
|
|
"type": "string",
|
|
"description": "The URL to fetch"
|
|
}
|
|
},
|
|
"required": ["url"]
|
|
}),
|
|
},
|
|
},
|
|
]
|
|
}
|
|
|
|
/// Send a streaming chat completion request to OpenRouter.
|
|
/// Sends events via the returned mpsc receiver:
|
|
/// - `StreamEvent::Delta(text)` for each content chunk
|
|
/// - `StreamEvent::ToolCalls(msg, stats)` if the AI wants to call tools
|
|
/// - `StreamEvent::Done(stats)` when the stream finishes with a text response
|
|
/// - `StreamEvent::Error(msg)` on error
|
|
pub async fn chat_completion_stream(
|
|
history: Vec<ChatMessage>,
|
|
model_id: &str,
|
|
api_key: &str,
|
|
tools: Option<Vec<Tool>>,
|
|
) -> tokio::sync::mpsc::Receiver<StreamEvent> {
|
|
let (tx, rx) = tokio::sync::mpsc::channel::<StreamEvent>(256);
|
|
|
|
let model_id = model_id.to_string();
|
|
let api_key = api_key.to_string();
|
|
|
|
tokio::spawn(async move {
|
|
let client = reqwest::Client::new();
|
|
|
|
let request_body = ChatRequest {
|
|
model: model_id.clone(),
|
|
messages: history,
|
|
max_tokens: Some(2048),
|
|
tools,
|
|
stream: Some(true),
|
|
};
|
|
|
|
let start = std::time::Instant::now();
|
|
|
|
let response = match client
|
|
.post(OPENROUTER_API_URL)
|
|
.header("Authorization", format!("Bearer {}", api_key))
|
|
.header("Content-Type", "application/json")
|
|
.header("HTTP-Referer", "http://localhost:3001")
|
|
.header("X-Title", "GroupChat")
|
|
.json(&request_body)
|
|
.send()
|
|
.await
|
|
{
|
|
Ok(r) => r,
|
|
Err(e) => {
|
|
let _ = tx
|
|
.send(StreamEvent::Error(format!("Request failed: {}", e)))
|
|
.await;
|
|
return;
|
|
}
|
|
};
|
|
|
|
if !response.status().is_success() {
|
|
let status = response.status();
|
|
let body = response.text().await.unwrap_or_default();
|
|
let _ = tx
|
|
.send(StreamEvent::Error(format!(
|
|
"OpenRouter error {}: {}",
|
|
status, body
|
|
)))
|
|
.await;
|
|
return;
|
|
}
|
|
|
|
// Read SSE stream
|
|
let mut byte_stream = response.bytes_stream();
|
|
let mut buffer = String::new();
|
|
let mut full_content = String::new();
|
|
let mut model_name = model_id.clone();
|
|
|
|
// Accumulators for streamed tool calls
|
|
let mut tool_call_accum: Vec<ToolCall> = Vec::new();
|
|
let mut has_tool_calls = false;
|
|
|
|
// Usage from the final chunk (some providers include it)
|
|
let mut final_usage: Option<Usage> = None;
|
|
|
|
while let Some(chunk_result) = byte_stream.next().await {
|
|
let bytes = match chunk_result {
|
|
Ok(b) => b,
|
|
Err(e) => {
|
|
let _ = tx
|
|
.send(StreamEvent::Error(format!("Stream error: {}", e)))
|
|
.await;
|
|
return;
|
|
}
|
|
};
|
|
|
|
buffer.push_str(&String::from_utf8_lossy(&bytes));
|
|
|
|
// Process complete SSE lines
|
|
while let Some(line_end) = buffer.find('\n') {
|
|
let line = buffer[..line_end].trim().to_string();
|
|
buffer = buffer[line_end + 1..].to_string();
|
|
|
|
if line.is_empty() || line.starts_with(':') {
|
|
continue;
|
|
}
|
|
|
|
if let Some(data) = line.strip_prefix("data: ") {
|
|
let data = data.trim();
|
|
|
|
if data == "[DONE]" {
|
|
// Stream finished
|
|
continue;
|
|
}
|
|
|
|
let chunk: StreamChunk = match serde_json::from_str(data) {
|
|
Ok(c) => c,
|
|
Err(_) => continue,
|
|
};
|
|
|
|
if let Some(m) = &chunk.model {
|
|
model_name = m.clone();
|
|
}
|
|
|
|
if let Some(u) = chunk.usage {
|
|
final_usage = Some(u);
|
|
}
|
|
|
|
if let Some(choices) = &chunk.choices {
|
|
if let Some(choice) = choices.first() {
|
|
if let Some(delta) = &choice.delta {
|
|
// Handle content delta
|
|
if let Some(content) = &delta.content {
|
|
if !content.is_empty() {
|
|
full_content.push_str(content);
|
|
let _ = tx.send(StreamEvent::Delta(content.clone())).await;
|
|
}
|
|
}
|
|
|
|
// Handle tool call deltas (accumulate incrementally)
|
|
if let Some(tcs) = &delta.tool_calls {
|
|
has_tool_calls = true;
|
|
for tc in tcs {
|
|
let idx = tc.index.unwrap_or(0);
|
|
|
|
// Ensure we have enough slots
|
|
while tool_call_accum.len() <= idx {
|
|
tool_call_accum.push(ToolCall {
|
|
id: String::new(),
|
|
r#type: "function".to_string(),
|
|
function: ToolCallFunction {
|
|
name: String::new(),
|
|
arguments: String::new(),
|
|
},
|
|
});
|
|
}
|
|
|
|
if let Some(id) = &tc.id {
|
|
tool_call_accum[idx].id = id.clone();
|
|
}
|
|
if let Some(func) = &tc.function {
|
|
if let Some(name) = &func.name {
|
|
tool_call_accum[idx].function.name.push_str(name);
|
|
}
|
|
if let Some(args) = &func.arguments {
|
|
tool_call_accum[idx]
|
|
.function
|
|
.arguments
|
|
.push_str(args);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
let elapsed_ms = start.elapsed().as_millis() as u64;
|
|
|
|
let (prompt_tokens, completion_tokens, total_tokens) = match &final_usage {
|
|
Some(u) => (
|
|
u.prompt_tokens.unwrap_or(0),
|
|
u.completion_tokens.unwrap_or(0),
|
|
u.total_tokens.unwrap_or(0),
|
|
),
|
|
None => (0, 0, 0),
|
|
};
|
|
|
|
let stats = CompletionStats {
|
|
model: model_name,
|
|
prompt_tokens,
|
|
completion_tokens,
|
|
total_tokens,
|
|
response_ms: elapsed_ms,
|
|
};
|
|
|
|
if has_tool_calls && !tool_call_accum.is_empty() {
|
|
// AI requested tool calls
|
|
let assistant_msg = ChatMessage {
|
|
role: "assistant".into(),
|
|
content: if full_content.is_empty() {
|
|
None
|
|
} else {
|
|
Some(Content::Text(full_content))
|
|
},
|
|
tool_calls: Some(tool_call_accum),
|
|
tool_call_id: None,
|
|
};
|
|
let _ = tx.send(StreamEvent::ToolCalls(assistant_msg, stats)).await;
|
|
} else {
|
|
// Normal text response completed
|
|
let _ = tx.send(StreamEvent::Done(stats)).await;
|
|
}
|
|
});
|
|
|
|
rx
|
|
}
|
|
|
|
/// Build the message history for OpenRouter from stored messages.
|
|
/// Includes the system prompt as the first message.
|
|
/// Messages with image data URLs will be sent as multimodal content.
|
|
pub fn build_chat_history(
|
|
system_prompt: &str,
|
|
messages: &[(String, String, bool, Option<String>)], // (sender_name, content, is_ai, image_data_url)
|
|
) -> Vec<ChatMessage> {
|
|
let mut history = vec![ChatMessage {
|
|
role: "system".to_string(),
|
|
content: Some(Content::Text(system_prompt.to_string())),
|
|
tool_calls: None,
|
|
tool_call_id: None,
|
|
}];
|
|
|
|
for (sender_name, content, is_ai, image_data_url) in messages {
|
|
if *is_ai {
|
|
history.push(ChatMessage {
|
|
role: "assistant".to_string(),
|
|
content: Some(Content::Text(content.clone())),
|
|
tool_calls: None,
|
|
tool_call_id: None,
|
|
});
|
|
} else {
|
|
let text = if content.is_empty() {
|
|
format!("[{}] shared an image:", sender_name)
|
|
} else {
|
|
format!("[{}]: {}", sender_name, content)
|
|
};
|
|
|
|
let msg_content = if let Some(data_url) = image_data_url {
|
|
Content::Parts(vec![
|
|
ContentPart::Text { text },
|
|
ContentPart::ImageUrl {
|
|
image_url: ImageUrlData {
|
|
url: data_url.clone(),
|
|
},
|
|
},
|
|
])
|
|
} else {
|
|
Content::Text(text)
|
|
};
|
|
|
|
history.push(ChatMessage {
|
|
role: "user".to_string(),
|
|
content: Some(msg_content),
|
|
tool_calls: None,
|
|
tool_call_id: None,
|
|
});
|
|
}
|
|
}
|
|
|
|
history
|
|
}
|