Document server code paths
This commit is contained in:
parent
927d106eae
commit
c37ff79514
@ -1,8 +1,8 @@
|
|||||||
use axum::{extract::State, http::StatusCode, Json};
|
|
||||||
use argon2::{
|
use argon2::{
|
||||||
password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
|
password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
|
||||||
Argon2,
|
Argon2,
|
||||||
};
|
};
|
||||||
|
use axum::{extract::State, http::StatusCode, Json};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
@ -12,6 +12,7 @@ use crate::{
|
|||||||
AppState,
|
AppState,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Create a new password-based account and immediately return a JWT.
|
||||||
pub async fn register(
|
pub async fn register(
|
||||||
State(state): State<Arc<AppState>>,
|
State(state): State<Arc<AppState>>,
|
||||||
Json(body): Json<RegisterRequest>,
|
Json(body): Json<RegisterRequest>,
|
||||||
@ -69,6 +70,7 @@ pub async fn register(
|
|||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Authenticate an existing password-based account and return a fresh JWT.
|
||||||
pub async fn login(
|
pub async fn login(
|
||||||
State(state): State<Arc<AppState>>,
|
State(state): State<Arc<AppState>>,
|
||||||
Json(body): Json<LoginRequest>,
|
Json(body): Json<LoginRequest>,
|
||||||
@ -87,8 +89,8 @@ pub async fn login(
|
|||||||
|
|
||||||
let (user_id, email, display_name, hash, avatar_url) = user;
|
let (user_id, email, display_name, hash, avatar_url) = user;
|
||||||
|
|
||||||
let parsed_hash = PasswordHash::new(&hash)
|
let parsed_hash =
|
||||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
PasswordHash::new(&hash).map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||||
|
|
||||||
Argon2::default()
|
Argon2::default()
|
||||||
.verify_password(body.password.as_bytes(), &parsed_hash)
|
.verify_password(body.password.as_bytes(), &parsed_hash)
|
||||||
@ -109,6 +111,7 @@ pub async fn login(
|
|||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Return the caller's current public profile information.
|
||||||
pub async fn me(
|
pub async fn me(
|
||||||
auth: AuthUser,
|
auth: AuthUser,
|
||||||
State(state): State<Arc<AppState>>,
|
State(state): State<Arc<AppState>>,
|
||||||
|
|||||||
@ -13,6 +13,7 @@ use crate::{
|
|||||||
AppState,
|
AppState,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Response payload for a newly created invite link.
|
||||||
#[derive(serde::Serialize)]
|
#[derive(serde::Serialize)]
|
||||||
pub struct InviteResponse {
|
pub struct InviteResponse {
|
||||||
pub id: String,
|
pub id: String,
|
||||||
@ -20,6 +21,7 @@ pub struct InviteResponse {
|
|||||||
pub invite_url: String,
|
pub invite_url: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create a one-time invite token for a room member to share.
|
||||||
pub async fn create_invite(
|
pub async fn create_invite(
|
||||||
State(state): State<Arc<AppState>>,
|
State(state): State<Arc<AppState>>,
|
||||||
auth: AuthUser,
|
auth: AuthUser,
|
||||||
@ -46,15 +48,17 @@ pub async fn create_invite(
|
|||||||
.map(char::from)
|
.map(char::from)
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
sqlx::query("INSERT INTO invites (id, room_id, invited_by, email, token) VALUES (?, ?, ?, ?, ?)")
|
sqlx::query(
|
||||||
.bind(&invite_id)
|
"INSERT INTO invites (id, room_id, invited_by, email, token) VALUES (?, ?, ?, ?, ?)",
|
||||||
.bind(&body.room_id)
|
)
|
||||||
.bind(&auth.user_id)
|
.bind(&invite_id)
|
||||||
.bind(&body.email)
|
.bind(&body.room_id)
|
||||||
.bind(&token)
|
.bind(&auth.user_id)
|
||||||
.execute(&state.db)
|
.bind(&body.email)
|
||||||
.await
|
.bind(&token)
|
||||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
.execute(&state.db)
|
||||||
|
.await
|
||||||
|
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||||
|
|
||||||
Ok(Json(InviteResponse {
|
Ok(Json(InviteResponse {
|
||||||
id: invite_id,
|
id: invite_id,
|
||||||
@ -63,11 +67,13 @@ pub async fn create_invite(
|
|||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Response payload returned after consuming an invite.
|
||||||
#[derive(serde::Serialize)]
|
#[derive(serde::Serialize)]
|
||||||
pub struct AcceptInviteResponse {
|
pub struct AcceptInviteResponse {
|
||||||
pub room_id: String,
|
pub room_id: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Consume an invite token and add the caller to the room.
|
||||||
pub async fn accept_invite(
|
pub async fn accept_invite(
|
||||||
State(state): State<Arc<AppState>>,
|
State(state): State<Arc<AppState>>,
|
||||||
auth: AuthUser,
|
auth: AuthUser,
|
||||||
@ -89,13 +95,12 @@ pub async fn accept_invite(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Verify room is not deleted
|
// Verify room is not deleted
|
||||||
let room_active = sqlx::query_scalar::<_, String>(
|
let room_active =
|
||||||
"SELECT id FROM rooms WHERE id = ? AND deleted_at IS NULL",
|
sqlx::query_scalar::<_, String>("SELECT id FROM rooms WHERE id = ? AND deleted_at IS NULL")
|
||||||
)
|
.bind(&room_id)
|
||||||
.bind(&room_id)
|
.fetch_optional(&state.db)
|
||||||
.fetch_optional(&state.db)
|
.await
|
||||||
.await
|
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
|
||||||
|
|
||||||
if room_active.is_none() {
|
if room_active.is_none() {
|
||||||
return Err((StatusCode::GONE, "This room has been deleted".into()));
|
return Err((StatusCode::GONE, "This room has been deleted".into()));
|
||||||
@ -119,6 +124,7 @@ pub async fn accept_invite(
|
|||||||
Ok(Json(AcceptInviteResponse { room_id }))
|
Ok(Json(AcceptInviteResponse { room_id }))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Result of a Nostr-based room invite attempt.
|
||||||
#[derive(serde::Serialize)]
|
#[derive(serde::Serialize)]
|
||||||
pub struct NostrInviteResponse {
|
pub struct NostrInviteResponse {
|
||||||
pub status: String,
|
pub status: String,
|
||||||
@ -126,6 +132,7 @@ pub struct NostrInviteResponse {
|
|||||||
pub display_name: Option<String>,
|
pub display_name: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Add a user to a room by their Nostr public key if they already have an account.
|
||||||
pub async fn invite_by_nostr(
|
pub async fn invite_by_nostr(
|
||||||
State(state): State<Arc<AppState>>,
|
State(state): State<Arc<AppState>>,
|
||||||
auth: AuthUser,
|
auth: AuthUser,
|
||||||
@ -138,8 +145,13 @@ pub async fn invite_by_nostr(
|
|||||||
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid npub format".to_string()))?
|
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid npub format".to_string()))?
|
||||||
} else {
|
} else {
|
||||||
// Validate it's valid hex
|
// Validate it's valid hex
|
||||||
if body.nostr_pubkey.len() != 64 || !body.nostr_pubkey.chars().all(|c| c.is_ascii_hexdigit()) {
|
if body.nostr_pubkey.len() != 64
|
||||||
return Err((StatusCode::BAD_REQUEST, "Invalid pubkey: must be 64-char hex or npub".to_string()));
|
|| !body.nostr_pubkey.chars().all(|c| c.is_ascii_hexdigit())
|
||||||
|
{
|
||||||
|
return Err((
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
"Invalid pubkey: must be 64-char hex or npub".to_string(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
body.nostr_pubkey.clone()
|
body.nostr_pubkey.clone()
|
||||||
};
|
};
|
||||||
|
|||||||
@ -1,3 +1,8 @@
|
|||||||
|
//! HTTP and WebSocket entry points for the server.
|
||||||
|
//!
|
||||||
|
//! Each submodule exposes route handlers that Axum wires into the router in
|
||||||
|
//! `main.rs`.
|
||||||
|
|
||||||
pub mod auth;
|
pub mod auth;
|
||||||
pub mod invites;
|
pub mod invites;
|
||||||
pub mod models;
|
pub mod models;
|
||||||
|
|||||||
@ -1,19 +1,16 @@
|
|||||||
use axum::{
|
use axum::{extract::State, http::StatusCode, Json};
|
||||||
extract::State,
|
|
||||||
http::StatusCode,
|
|
||||||
Json,
|
|
||||||
};
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::sync::OnceCell;
|
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
|
use tokio::sync::OnceCell;
|
||||||
|
|
||||||
use crate::AppState;
|
use crate::AppState;
|
||||||
|
|
||||||
/// Cached model list with expiry.
|
/// Cached model list with expiry.
|
||||||
static MODEL_CACHE: OnceCell<Mutex<CachedModels>> = OnceCell::const_new();
|
static MODEL_CACHE: OnceCell<Mutex<CachedModels>> = OnceCell::const_new();
|
||||||
|
|
||||||
|
/// Process-wide cache for the OpenRouter model catalog.
|
||||||
struct CachedModels {
|
struct CachedModels {
|
||||||
models: Vec<ModelInfo>,
|
models: Vec<ModelInfo>,
|
||||||
fetched_at: Instant,
|
fetched_at: Instant,
|
||||||
@ -21,6 +18,7 @@ struct CachedModels {
|
|||||||
|
|
||||||
const CACHE_TTL: Duration = Duration::from_secs(60 * 30); // 30 minutes
|
const CACHE_TTL: Duration = Duration::from_secs(60 * 30); // 30 minutes
|
||||||
|
|
||||||
|
/// Model metadata exposed to the client for room creation and model selection.
|
||||||
#[derive(Debug, Clone, Serialize)]
|
#[derive(Debug, Clone, Serialize)]
|
||||||
pub struct ModelInfo {
|
pub struct ModelInfo {
|
||||||
pub id: String,
|
pub id: String,
|
||||||
@ -56,6 +54,10 @@ struct OpenRouterArchitecture {
|
|||||||
input_modalities: Option<Vec<String>>,
|
input_modalities: Option<Vec<String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Fetch the model catalog directly from OpenRouter.
|
||||||
|
///
|
||||||
|
/// The result is normalized into the smaller `ModelInfo` shape that the client
|
||||||
|
/// UI needs.
|
||||||
async fn fetch_models(api_key: &str) -> Result<Vec<ModelInfo>, String> {
|
async fn fetch_models(api_key: &str) -> Result<Vec<ModelInfo>, String> {
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
|
|
||||||
@ -82,7 +84,8 @@ async fn fetch_models(api_key: &str) -> Result<Vec<ModelInfo>, String> {
|
|||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|m| {
|
.map(|m| {
|
||||||
let pricing = m.pricing.as_ref();
|
let pricing = m.pricing.as_ref();
|
||||||
let supports_vision = m.architecture
|
let supports_vision = m
|
||||||
|
.architecture
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.and_then(|a| a.input_modalities.as_ref())
|
.and_then(|a| a.input_modalities.as_ref())
|
||||||
.map(|mods| mods.iter().any(|m| m == "image"))
|
.map(|mods| mods.iter().any(|m| m == "image"))
|
||||||
@ -102,6 +105,7 @@ async fn fetch_models(api_key: &str) -> Result<Vec<ModelInfo>, String> {
|
|||||||
Ok(models)
|
Ok(models)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Return the cached OpenRouter model list, refreshing it when the cache expires.
|
||||||
pub async fn list_models(
|
pub async fn list_models(
|
||||||
State(state): State<Arc<AppState>>,
|
State(state): State<Arc<AppState>>,
|
||||||
) -> Result<Json<Vec<ModelInfo>>, (StatusCode, String)> {
|
) -> Result<Json<Vec<ModelInfo>>, (StatusCode, String)> {
|
||||||
|
|||||||
@ -10,6 +10,7 @@ use crate::{
|
|||||||
AppState,
|
AppState,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Claims embedded in the short-lived challenge token used during Nostr login.
|
||||||
#[derive(Debug, serde::Serialize, serde::Deserialize)]
|
#[derive(Debug, serde::Serialize, serde::Deserialize)]
|
||||||
struct ChallengeClaims {
|
struct ChallengeClaims {
|
||||||
pub nonce: String,
|
pub nonce: String,
|
||||||
@ -28,10 +29,7 @@ pub async fn challenge(
|
|||||||
|
|
||||||
let exp = (chrono::Utc::now().timestamp() + 120) as usize; // 2 minutes
|
let exp = (chrono::Utc::now().timestamp() + 120) as usize; // 2 minutes
|
||||||
|
|
||||||
let claims = ChallengeClaims {
|
let claims = ChallengeClaims { nonce, exp };
|
||||||
nonce,
|
|
||||||
exp,
|
|
||||||
};
|
|
||||||
|
|
||||||
let token = encode(
|
let token = encode(
|
||||||
&Header::default(),
|
&Header::default(),
|
||||||
@ -45,6 +43,7 @@ pub async fn challenge(
|
|||||||
|
|
||||||
/// Simple hex encoder (avoid adding the `hex` crate just for this)
|
/// Simple hex encoder (avoid adding the `hex` crate just for this)
|
||||||
mod hex {
|
mod hex {
|
||||||
|
/// Convert raw bytes into a lowercase hexadecimal string.
|
||||||
pub fn encode(bytes: &[u8]) -> String {
|
pub fn encode(bytes: &[u8]) -> String {
|
||||||
bytes.iter().map(|b| format!("{:02x}", b)).collect()
|
bytes.iter().map(|b| format!("{:02x}", b)).collect()
|
||||||
}
|
}
|
||||||
@ -61,17 +60,29 @@ pub async fn verify(
|
|||||||
&DecodingKey::from_secret(state.jwt_secret.as_bytes()),
|
&DecodingKey::from_secret(state.jwt_secret.as_bytes()),
|
||||||
&Validation::default(),
|
&Validation::default(),
|
||||||
)
|
)
|
||||||
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid or expired challenge".to_string()))?;
|
.map_err(|_| {
|
||||||
|
(
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
"Invalid or expired challenge".to_string(),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
let nonce = &challenge_data.claims.nonce;
|
let nonce = &challenge_data.claims.nonce;
|
||||||
|
|
||||||
// 2. Deserialize signed_event as nostr::Event
|
// 2. Deserialize signed_event as nostr::Event
|
||||||
let event: Event = serde_json::from_str(&body.signed_event)
|
let event: Event = serde_json::from_str(&body.signed_event).map_err(|e| {
|
||||||
.map_err(|e| (StatusCode::BAD_REQUEST, format!("Invalid event JSON: {}", e)))?;
|
(
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
format!("Invalid event JSON: {}", e),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
// 3. Verify Schnorr signature
|
// 3. Verify Schnorr signature
|
||||||
if !event.verify_signature() {
|
if !event.verify_signature() {
|
||||||
return Err((StatusCode::UNAUTHORIZED, "Invalid event signature".to_string()));
|
return Err((
|
||||||
|
StatusCode::UNAUTHORIZED,
|
||||||
|
"Invalid event signature".to_string(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. Verify event.content == nonce
|
// 4. Verify event.content == nonce
|
||||||
@ -83,7 +94,10 @@ pub async fn verify(
|
|||||||
let now = chrono::Utc::now().timestamp() as u64;
|
let now = chrono::Utc::now().timestamp() as u64;
|
||||||
let event_ts = event.created_at.as_secs();
|
let event_ts = event.created_at.as_secs();
|
||||||
if now.abs_diff(event_ts) > 300 {
|
if now.abs_diff(event_ts) > 300 {
|
||||||
return Err((StatusCode::BAD_REQUEST, "Event timestamp too far off".to_string()));
|
return Err((
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
"Event timestamp too far off".to_string(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
// 6. Extract pubkey hex
|
// 6. Extract pubkey hex
|
||||||
|
|||||||
@ -11,6 +11,7 @@ use crate::{
|
|||||||
AppState,
|
AppState,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Request body for profile updates.
|
||||||
#[derive(Debug, serde::Deserialize)]
|
#[derive(Debug, serde::Deserialize)]
|
||||||
pub struct UpdateProfileRequest {
|
pub struct UpdateProfileRequest {
|
||||||
pub display_name: Option<String>,
|
pub display_name: Option<String>,
|
||||||
@ -25,7 +26,10 @@ pub async fn update_profile(
|
|||||||
let display_name = body.display_name.unwrap_or(auth.display_name.clone());
|
let display_name = body.display_name.unwrap_or(auth.display_name.clone());
|
||||||
|
|
||||||
if display_name.trim().is_empty() {
|
if display_name.trim().is_empty() {
|
||||||
return Err((StatusCode::BAD_REQUEST, "Display name cannot be empty".into()));
|
return Err((
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
"Display name cannot be empty".into(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
sqlx::query("UPDATE users SET display_name = ? WHERE id = ?")
|
sqlx::query("UPDATE users SET display_name = ? WHERE id = ?")
|
||||||
@ -83,7 +87,12 @@ pub async fn upload_avatar(
|
|||||||
"image/jpeg" | "image/jpg" => "jpg",
|
"image/jpeg" | "image/jpg" => "jpg",
|
||||||
"image/gif" => "gif",
|
"image/gif" => "gif",
|
||||||
"image/webp" => "webp",
|
"image/webp" => "webp",
|
||||||
_ => return Err((StatusCode::BAD_REQUEST, "Only PNG, JPG, GIF, and WebP images are allowed".into())),
|
_ => {
|
||||||
|
return Err((
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
"Only PNG, JPG, GIF, and WebP images are allowed".into(),
|
||||||
|
))
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let data = field
|
let data = field
|
||||||
@ -130,8 +139,13 @@ pub async fn upload_avatar(
|
|||||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||||
|
|
||||||
// Issue new token
|
// Issue new token
|
||||||
let token = create_token(&auth.user_id, &auth.email, &auth.display_name, &state.jwt_secret)
|
let token = create_token(
|
||||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
&auth.user_id,
|
||||||
|
&auth.email,
|
||||||
|
&auth.display_name,
|
||||||
|
&state.jwt_secret,
|
||||||
|
)
|
||||||
|
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||||
|
|
||||||
Ok(Json(AuthResponse {
|
Ok(Json(AuthResponse {
|
||||||
token,
|
token,
|
||||||
@ -168,8 +182,13 @@ pub async fn delete_avatar(
|
|||||||
.await
|
.await
|
||||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||||
|
|
||||||
let token = create_token(&auth.user_id, &auth.email, &auth.display_name, &state.jwt_secret)
|
let token = create_token(
|
||||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
&auth.user_id,
|
||||||
|
&auth.email,
|
||||||
|
&auth.display_name,
|
||||||
|
&state.jwt_secret,
|
||||||
|
)
|
||||||
|
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||||
|
|
||||||
Ok(Json(AuthResponse {
|
Ok(Json(AuthResponse {
|
||||||
token,
|
token,
|
||||||
|
|||||||
@ -8,10 +8,13 @@ use uuid::Uuid;
|
|||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
middleware::auth::AuthUser,
|
middleware::auth::AuthUser,
|
||||||
models::{self, CreateRoomRequest, MessagePayload, PaginationParams, Room, RoomResponse, UserPublic},
|
models::{
|
||||||
|
self, CreateRoomRequest, MessagePayload, PaginationParams, Room, RoomResponse, UserPublic,
|
||||||
|
},
|
||||||
AppState,
|
AppState,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Create a room, persist it, and add the creator as the first member.
|
||||||
pub async fn create_room(
|
pub async fn create_room(
|
||||||
State(state): State<Arc<AppState>>,
|
State(state): State<Arc<AppState>>,
|
||||||
auth: AuthUser,
|
auth: AuthUser,
|
||||||
@ -60,6 +63,7 @@ pub async fn create_room(
|
|||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// List all active rooms the caller belongs to, including current room members.
|
||||||
pub async fn list_rooms(
|
pub async fn list_rooms(
|
||||||
State(state): State<Arc<AppState>>,
|
State(state): State<Arc<AppState>>,
|
||||||
auth: AuthUser,
|
auth: AuthUser,
|
||||||
@ -93,13 +97,15 @@ pub async fn list_rooms(
|
|||||||
created_at: room.created_at,
|
created_at: room.created_at,
|
||||||
members: members
|
members: members
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|(id, email, display_name, avatar_url, nostr_pubkey)| UserPublic {
|
.map(
|
||||||
id,
|
|(id, email, display_name, avatar_url, nostr_pubkey)| UserPublic {
|
||||||
email: models::public_email(&email),
|
id,
|
||||||
display_name,
|
email: models::public_email(&email),
|
||||||
avatar_url,
|
display_name,
|
||||||
nostr_pubkey,
|
avatar_url,
|
||||||
})
|
nostr_pubkey,
|
||||||
|
},
|
||||||
|
)
|
||||||
.collect(),
|
.collect(),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -107,6 +113,7 @@ pub async fn list_rooms(
|
|||||||
Ok(Json(result))
|
Ok(Json(result))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Return details for a single room after verifying the caller is a member.
|
||||||
pub async fn get_room(
|
pub async fn get_room(
|
||||||
State(state): State<Arc<AppState>>,
|
State(state): State<Arc<AppState>>,
|
||||||
auth: AuthUser,
|
auth: AuthUser,
|
||||||
@ -152,17 +159,20 @@ pub async fn get_room(
|
|||||||
created_at: room.created_at,
|
created_at: room.created_at,
|
||||||
members: members
|
members: members
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|(id, email, display_name, avatar_url, nostr_pubkey)| UserPublic {
|
.map(
|
||||||
id,
|
|(id, email, display_name, avatar_url, nostr_pubkey)| UserPublic {
|
||||||
email: models::public_email(&email),
|
id,
|
||||||
display_name,
|
email: models::public_email(&email),
|
||||||
avatar_url,
|
display_name,
|
||||||
nostr_pubkey,
|
avatar_url,
|
||||||
})
|
nostr_pubkey,
|
||||||
|
},
|
||||||
|
)
|
||||||
.collect(),
|
.collect(),
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Return paginated message history for a room the caller can access.
|
||||||
pub async fn get_messages(
|
pub async fn get_messages(
|
||||||
State(state): State<Arc<AppState>>,
|
State(state): State<Arc<AppState>>,
|
||||||
auth: AuthUser,
|
auth: AuthUser,
|
||||||
@ -208,37 +218,56 @@ pub async fn get_messages(
|
|||||||
}
|
}
|
||||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||||
|
|
||||||
|
// The SQL query reads newest-first for efficient pagination, but clients
|
||||||
|
// render chat oldest-to-newest, so reverse the rows before serializing.
|
||||||
let payloads: Vec<MessagePayload> = rows
|
let payloads: Vec<MessagePayload> = rows
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.rev()
|
.rev()
|
||||||
.map(|(id, room_id, sender_id, sender_name, content, mentions, is_ai, created_at, ai_meta_str, image_url, email, avatar_url, hash)| {
|
.map(
|
||||||
let ai_meta = ai_meta_str
|
|(
|
||||||
.as_deref()
|
|
||||||
.and_then(|s| serde_json::from_str::<crate::models::AiMeta>(s).ok());
|
|
||||||
let avatar_hash = email
|
|
||||||
.map(|e| crate::models::gravatar_hash(&e))
|
|
||||||
.unwrap_or_default();
|
|
||||||
MessagePayload {
|
|
||||||
id,
|
id,
|
||||||
room_id,
|
room_id,
|
||||||
sender_id,
|
sender_id,
|
||||||
sender_name,
|
sender_name,
|
||||||
content,
|
content,
|
||||||
mentions: serde_json::from_str(&mentions).unwrap_or_default(),
|
mentions,
|
||||||
is_ai,
|
is_ai,
|
||||||
created_at,
|
created_at,
|
||||||
ai_meta,
|
ai_meta_str,
|
||||||
avatar_hash,
|
|
||||||
avatar_url,
|
|
||||||
image_url,
|
image_url,
|
||||||
|
email,
|
||||||
|
avatar_url,
|
||||||
hash,
|
hash,
|
||||||
}
|
)| {
|
||||||
})
|
let ai_meta = ai_meta_str
|
||||||
|
.as_deref()
|
||||||
|
.and_then(|s| serde_json::from_str::<crate::models::AiMeta>(s).ok());
|
||||||
|
let avatar_hash = email
|
||||||
|
.map(|e| crate::models::gravatar_hash(&e))
|
||||||
|
.unwrap_or_default();
|
||||||
|
MessagePayload {
|
||||||
|
id,
|
||||||
|
room_id,
|
||||||
|
sender_id,
|
||||||
|
sender_name,
|
||||||
|
content,
|
||||||
|
mentions: serde_json::from_str(&mentions).unwrap_or_default(),
|
||||||
|
is_ai,
|
||||||
|
created_at,
|
||||||
|
ai_meta,
|
||||||
|
avatar_hash,
|
||||||
|
avatar_url,
|
||||||
|
image_url,
|
||||||
|
hash,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
Ok(Json(payloads))
|
Ok(Json(payloads))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Resolve a stable message hash into the room that contains it.
|
||||||
pub async fn resolve_message_hash(
|
pub async fn resolve_message_hash(
|
||||||
State(state): State<Arc<AppState>>,
|
State(state): State<Arc<AppState>>,
|
||||||
auth: AuthUser,
|
auth: AuthUser,
|
||||||
@ -258,22 +287,29 @@ pub async fn resolve_message_hash(
|
|||||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||||
|
|
||||||
match row {
|
match row {
|
||||||
Some((room_id,)) => Ok(Json(serde_json::json!({ "room_id": room_id, "hash": hash }))),
|
Some((room_id,)) => Ok(Json(
|
||||||
None => Err((StatusCode::NOT_FOUND, "Message not found or no access".into())),
|
serde_json::json!({ "room_id": room_id, "hash": hash }),
|
||||||
|
)),
|
||||||
|
None => Err((
|
||||||
|
StatusCode::NOT_FOUND,
|
||||||
|
"Message not found or no access".into(),
|
||||||
|
)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Add the caller to a room directly when they already know its ID.
|
||||||
pub async fn join_room(
|
pub async fn join_room(
|
||||||
State(state): State<Arc<AppState>>,
|
State(state): State<Arc<AppState>>,
|
||||||
auth: AuthUser,
|
auth: AuthUser,
|
||||||
Path(room_id): Path<String>,
|
Path(room_id): Path<String>,
|
||||||
) -> Result<StatusCode, (StatusCode, String)> {
|
) -> Result<StatusCode, (StatusCode, String)> {
|
||||||
// Check room exists
|
// Check room exists
|
||||||
let room_exists = sqlx::query_scalar::<_, String>("SELECT id FROM rooms WHERE id = ? AND deleted_at IS NULL")
|
let room_exists =
|
||||||
.bind(&room_id)
|
sqlx::query_scalar::<_, String>("SELECT id FROM rooms WHERE id = ? AND deleted_at IS NULL")
|
||||||
.fetch_optional(&state.db)
|
.bind(&room_id)
|
||||||
.await
|
.fetch_optional(&state.db)
|
||||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
.await
|
||||||
|
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||||
|
|
||||||
if room_exists.is_none() {
|
if room_exists.is_none() {
|
||||||
return Err((StatusCode::NOT_FOUND, "Room not found".into()));
|
return Err((StatusCode::NOT_FOUND, "Room not found".into()));
|
||||||
@ -289,6 +325,7 @@ pub async fn join_room(
|
|||||||
Ok(StatusCode::OK)
|
Ok(StatusCode::OK)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Soft-delete a room and broadcast the deletion event to connected members.
|
||||||
pub async fn delete_room(
|
pub async fn delete_room(
|
||||||
State(state): State<Arc<AppState>>,
|
State(state): State<Arc<AppState>>,
|
||||||
auth: AuthUser,
|
auth: AuthUser,
|
||||||
@ -303,7 +340,10 @@ pub async fn delete_room(
|
|||||||
.ok_or((StatusCode::NOT_FOUND, "Room not found".into()))?;
|
.ok_or((StatusCode::NOT_FOUND, "Room not found".into()))?;
|
||||||
|
|
||||||
if room.created_by != auth.user_id {
|
if room.created_by != auth.user_id {
|
||||||
return Err((StatusCode::FORBIDDEN, "Only the room creator can delete this room".into()));
|
return Err((
|
||||||
|
StatusCode::FORBIDDEN,
|
||||||
|
"Only the room creator can delete this room".into(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Soft-delete
|
// Soft-delete
|
||||||
@ -324,6 +364,7 @@ pub async fn delete_room(
|
|||||||
Ok(StatusCode::OK)
|
Ok(StatusCode::OK)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Permanently remove all messages from a room without deleting the room itself.
|
||||||
pub async fn clear_room(
|
pub async fn clear_room(
|
||||||
State(state): State<Arc<AppState>>,
|
State(state): State<Arc<AppState>>,
|
||||||
auth: AuthUser,
|
auth: AuthUser,
|
||||||
@ -338,7 +379,10 @@ pub async fn clear_room(
|
|||||||
.ok_or((StatusCode::NOT_FOUND, "Room not found".into()))?;
|
.ok_or((StatusCode::NOT_FOUND, "Room not found".into()))?;
|
||||||
|
|
||||||
if room.created_by != auth.user_id {
|
if room.created_by != auth.user_id {
|
||||||
return Err((StatusCode::FORBIDDEN, "Only the room creator can clear messages".into()));
|
return Err((
|
||||||
|
StatusCode::FORBIDDEN,
|
||||||
|
"Only the room creator can clear messages".into(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Hard-delete all messages
|
// Hard-delete all messages
|
||||||
|
|||||||
@ -1,18 +1,16 @@
|
|||||||
use axum::{
|
use axum::{extract::Multipart, http::StatusCode, Json};
|
||||||
extract::Multipart,
|
|
||||||
http::StatusCode,
|
|
||||||
Json,
|
|
||||||
};
|
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::middleware::auth::AuthUser;
|
use crate::middleware::auth::AuthUser;
|
||||||
|
|
||||||
|
/// Response returned after a chat image upload succeeds.
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize)]
|
||||||
pub struct UploadResponse {
|
pub struct UploadResponse {
|
||||||
pub url: String,
|
pub url: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Accept a multipart chat image upload and store it under `uploads/chat-images`.
|
||||||
pub async fn upload_chat_image(
|
pub async fn upload_chat_image(
|
||||||
_auth: AuthUser,
|
_auth: AuthUser,
|
||||||
mut multipart: Multipart,
|
mut multipart: Multipart,
|
||||||
|
|||||||
@ -1,3 +1,9 @@
|
|||||||
|
//! WebSocket workflow for live chat delivery and AI responses.
|
||||||
|
//!
|
||||||
|
//! This module does two jobs:
|
||||||
|
//! - fan out database-backed room events to subscribed browser sockets
|
||||||
|
//! - turn incoming user chat messages into stored messages and optional AI replies
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{
|
extract::{
|
||||||
ws::{Message, WebSocket},
|
ws::{Message, WebSocket},
|
||||||
@ -24,6 +30,7 @@ pub struct WsQuery {
|
|||||||
token: String,
|
token: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Upgrade an authenticated request into a WebSocket connection.
|
||||||
pub async fn ws_handler(
|
pub async fn ws_handler(
|
||||||
ws: WebSocketUpgrade,
|
ws: WebSocketUpgrade,
|
||||||
State(state): State<Arc<AppState>>,
|
State(state): State<Arc<AppState>>,
|
||||||
@ -37,10 +44,19 @@ pub async fn ws_handler(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
ws.on_upgrade(move |socket| handle_socket(socket, state, claims.sub, claims.display_name, claims.email))
|
ws.on_upgrade(move |socket| {
|
||||||
|
handle_socket(socket, state, claims.sub, claims.display_name, claims.email)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_socket(socket: WebSocket, state: Arc<AppState>, user_id: String, display_name: String, email: String) {
|
/// Drive a single WebSocket connection until either the send or receive side ends.
|
||||||
|
async fn handle_socket(
|
||||||
|
socket: WebSocket,
|
||||||
|
state: Arc<AppState>,
|
||||||
|
user_id: String,
|
||||||
|
display_name: String,
|
||||||
|
email: String,
|
||||||
|
) {
|
||||||
let (mut ws_tx, mut ws_rx) = socket.split();
|
let (mut ws_tx, mut ws_rx) = socket.split();
|
||||||
let mut broadcast_rx = state.tx.subscribe();
|
let mut broadcast_rx = state.tx.subscribe();
|
||||||
|
|
||||||
@ -50,7 +66,8 @@ async fn handle_socket(socket: WebSocket, state: Arc<AppState>, user_id: String,
|
|||||||
|
|
||||||
let rooms_clone = subscribed_rooms.clone();
|
let rooms_clone = subscribed_rooms.clone();
|
||||||
|
|
||||||
// Task: forward broadcast events to this client
|
// Task 1: forward room events from the shared broadcast channel into this
|
||||||
|
// specific socket, but only for rooms the browser subscribed to.
|
||||||
let mut send_task = tokio::spawn(async move {
|
let mut send_task = tokio::spawn(async move {
|
||||||
loop {
|
loop {
|
||||||
match broadcast_rx.recv().await {
|
match broadcast_rx.recv().await {
|
||||||
@ -81,7 +98,8 @@ async fn handle_socket(socket: WebSocket, state: Arc<AppState>, user_id: String,
|
|||||||
let email_clone = email.clone();
|
let email_clone = email.clone();
|
||||||
let rooms_clone2 = subscribed_rooms.clone();
|
let rooms_clone2 = subscribed_rooms.clone();
|
||||||
|
|
||||||
// Task: receive messages from client
|
// Task 2: receive commands from the browser and translate them into
|
||||||
|
// database writes, broadcasts, or AI work.
|
||||||
let mut recv_task = tokio::spawn(async move {
|
let mut recv_task = tokio::spawn(async move {
|
||||||
while let Some(Ok(msg)) = ws_rx.next().await {
|
while let Some(Ok(msg)) = ws_rx.next().await {
|
||||||
let text = match msg {
|
let text = match msg {
|
||||||
@ -141,7 +159,7 @@ async fn handle_socket(socket: WebSocket, state: Arc<AppState>, user_id: String,
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
// Wait for either task to finish, then abort the other
|
// If either half of the connection ends, stop the companion task too.
|
||||||
tokio::select! {
|
tokio::select! {
|
||||||
_ = &mut send_task => recv_task.abort(),
|
_ = &mut send_task => recv_task.abort(),
|
||||||
_ = &mut recv_task => send_task.abort(),
|
_ = &mut recv_task => send_task.abort(),
|
||||||
@ -150,6 +168,7 @@ async fn handle_socket(socket: WebSocket, state: Arc<AppState>, user_id: String,
|
|||||||
tracing::info!("WebSocket disconnected: {}", user_id);
|
tracing::info!("WebSocket disconnected: {}", user_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Persist a user message, broadcast it, and optionally generate an AI reply.
|
||||||
async fn handle_send_message(
|
async fn handle_send_message(
|
||||||
state: &Arc<AppState>,
|
state: &Arc<AppState>,
|
||||||
user_id: &str,
|
user_id: &str,
|
||||||
@ -184,13 +203,14 @@ async fn handle_send_message(
|
|||||||
.await;
|
.await;
|
||||||
|
|
||||||
// Look up the sender's custom avatar (if any) for the message payload
|
// Look up the sender's custom avatar (if any) for the message payload
|
||||||
let avatar_url: Option<String> = sqlx::query_scalar("SELECT avatar_url FROM users WHERE id = ?")
|
let avatar_url: Option<String> =
|
||||||
.bind(user_id)
|
sqlx::query_scalar("SELECT avatar_url FROM users WHERE id = ?")
|
||||||
.fetch_optional(&state.db)
|
.bind(user_id)
|
||||||
.await
|
.fetch_optional(&state.db)
|
||||||
.ok()
|
.await
|
||||||
.flatten()
|
.ok()
|
||||||
.flatten();
|
.flatten()
|
||||||
|
.flatten();
|
||||||
|
|
||||||
// Broadcast human message
|
// Broadcast human message
|
||||||
let payload = MessagePayload {
|
let payload = MessagePayload {
|
||||||
@ -211,12 +231,11 @@ async fn handle_send_message(
|
|||||||
|
|
||||||
let _ = state.tx.send(BroadcastEvent {
|
let _ = state.tx.send(BroadcastEvent {
|
||||||
room_id: room_id.to_string(),
|
room_id: room_id.to_string(),
|
||||||
message: WsServerMessage::NewMessage {
|
message: WsServerMessage::NewMessage { message: payload },
|
||||||
message: payload,
|
|
||||||
},
|
|
||||||
});
|
});
|
||||||
|
|
||||||
// Check if AI should respond
|
// The AI only replies when explicitly mentioned or when the room is set to
|
||||||
|
// auto-reply to every message.
|
||||||
let ai_user_id = "ai-assistant";
|
let ai_user_id = "ai-assistant";
|
||||||
let should_respond = mentions.contains(&ai_user_id.to_string());
|
let should_respond = mentions.contains(&ai_user_id.to_string());
|
||||||
|
|
||||||
@ -254,7 +273,8 @@ async fn handle_send_message(
|
|||||||
.await
|
.await
|
||||||
.unwrap_or_default();
|
.unwrap_or_default();
|
||||||
|
|
||||||
// Process history: encode images as base64 data URLs for OpenRouter
|
// OpenRouter accepts image inputs as data URLs, so local uploads need to be
|
||||||
|
// loaded from disk and encoded before they are sent upstream.
|
||||||
let mut history: Vec<(String, String, bool, Option<String>)> = Vec::new();
|
let mut history: Vec<(String, String, bool, Option<String>)> = Vec::new();
|
||||||
for (sender_name, msg_content, is_ai, msg_image_url) in recent_messages.into_iter().rev() {
|
for (sender_name, msg_content, is_ai, msg_image_url) in recent_messages.into_iter().rev() {
|
||||||
let image_data_url = match &msg_image_url {
|
let image_data_url = match &msg_image_url {
|
||||||
@ -272,7 +292,8 @@ async fn handle_send_message(
|
|||||||
// Pre-generate AI message ID so we can reference it in stream chunks
|
// Pre-generate AI message ID so we can reference it in stream chunks
|
||||||
let ai_msg_id = Uuid::new_v4().to_string();
|
let ai_msg_id = Uuid::new_v4().to_string();
|
||||||
|
|
||||||
// Call OpenRouter with tool loop — uses streaming for all rounds
|
// Run the AI in a loop because the model may first request tools, then need
|
||||||
|
// follow-up rounds after those tool results are added to history.
|
||||||
let mut total_prompt_tokens: u32 = 0;
|
let mut total_prompt_tokens: u32 = 0;
|
||||||
let mut total_completion_tokens: u32 = 0;
|
let mut total_completion_tokens: u32 = 0;
|
||||||
let mut total_response_ms: u64 = 0;
|
let mut total_response_ms: u64 = 0;
|
||||||
@ -313,16 +334,24 @@ async fn handle_send_message(
|
|||||||
tracing::info!(
|
tracing::info!(
|
||||||
"AI requesting tool calls (round {}): {:?}",
|
"AI requesting tool calls (round {}): {:?}",
|
||||||
round + 1,
|
round + 1,
|
||||||
assistant_msg.tool_calls.as_ref().map(|tc| tc.iter().map(|t| &t.function.name).collect::<Vec<_>>())
|
assistant_msg
|
||||||
|
.tool_calls
|
||||||
|
.as_ref()
|
||||||
|
.map(|tc| tc.iter().map(|t| &t.function.name).collect::<Vec<_>>())
|
||||||
);
|
);
|
||||||
|
|
||||||
// Add the assistant's tool-call message to history
|
// Preserve the assistant tool-call message so the next round
|
||||||
|
// has the same context the model produced.
|
||||||
let tool_calls = assistant_msg.tool_calls.clone().unwrap_or_default();
|
let tool_calls = assistant_msg.tool_calls.clone().unwrap_or_default();
|
||||||
chat_history.push(assistant_msg);
|
chat_history.push(assistant_msg);
|
||||||
|
|
||||||
// Execute each tool call and add results
|
// Tool results are fed back into the conversation as
|
||||||
|
// synthetic `tool` messages, matching the upstream API.
|
||||||
for tool_call in &tool_calls {
|
for tool_call in &tool_calls {
|
||||||
let tool_input = extract_tool_input(&tool_call.function.name, &tool_call.function.arguments);
|
let tool_input = extract_tool_input(
|
||||||
|
&tool_call.function.name,
|
||||||
|
&tool_call.function.arguments,
|
||||||
|
);
|
||||||
|
|
||||||
// Broadcast real-time tool usage event
|
// Broadcast real-time tool usage event
|
||||||
let _ = state.tx.send(BroadcastEvent {
|
let _ = state.tx.send(BroadcastEvent {
|
||||||
@ -362,7 +391,7 @@ async fn handle_send_message(
|
|||||||
tool_call_id: Some(tool_call.id.clone()),
|
tool_call_id: Some(tool_call.id.clone()),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
// Continue to next round (tool loop)
|
// Ask the model to continue now that tool output exists.
|
||||||
continue 'tool_loop;
|
continue 'tool_loop;
|
||||||
}
|
}
|
||||||
openrouter::StreamEvent::Done(stats) => {
|
openrouter::StreamEvent::Done(stats) => {
|
||||||
@ -382,9 +411,12 @@ async fn handle_send_message(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we exhausted all rounds without a text response, note it
|
// Guardrail: if the model never produced final prose, store a clear fallback
|
||||||
|
// instead of leaving the client waiting indefinitely.
|
||||||
if ai_response.is_empty() && !had_error {
|
if ai_response.is_empty() && !had_error {
|
||||||
ai_response = "*I used several tools but couldn't formulate a final response. Please try again.*".to_string();
|
ai_response =
|
||||||
|
"*I used several tools but couldn't formulate a final response. Please try again.*"
|
||||||
|
.to_string();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Signal stream end so client can finalize rendering
|
// Signal stream end so client can finalize rendering
|
||||||
@ -512,7 +544,15 @@ async fn execute_tool(
|
|||||||
return "Error: search query is required".into();
|
return "Error: search query is required".into();
|
||||||
}
|
}
|
||||||
|
|
||||||
match search::search(search_provider, &query, tavily_api_key, brave_api_key, count).await {
|
match search::search(
|
||||||
|
search_provider,
|
||||||
|
&query,
|
||||||
|
tavily_api_key,
|
||||||
|
brave_api_key,
|
||||||
|
count,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
Ok(results) => search::format_results(&results),
|
Ok(results) => search::format_results(&results),
|
||||||
Err(e) => format!("Search error: {}", e),
|
Err(e) => format!("Search error: {}", e),
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,3 +1,12 @@
|
|||||||
|
//! Application bootstrap for the GroupChat server.
|
||||||
|
//!
|
||||||
|
//! This file is responsible for:
|
||||||
|
//! - loading environment configuration
|
||||||
|
//! - opening and migrating the SQLite database
|
||||||
|
//! - constructing shared application state
|
||||||
|
//! - registering HTTP/WebSocket routes
|
||||||
|
//! - serving the SPA frontend in production
|
||||||
|
|
||||||
mod handlers;
|
mod handlers;
|
||||||
mod middleware;
|
mod middleware;
|
||||||
mod models;
|
mod models;
|
||||||
@ -51,14 +60,8 @@ fn backup_database(database_url: &str) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Build timestamped backup filename: chat.db -> chat_2026-03-09_143022.db
|
// Build timestamped backup filename: chat.db -> chat_2026-03-09_143022.db
|
||||||
let stem = db_file
|
let stem = db_file.file_stem().and_then(|s| s.to_str()).unwrap_or("db");
|
||||||
.file_stem()
|
let ext = db_file.extension().and_then(|s| s.to_str()).unwrap_or("db");
|
||||||
.and_then(|s| s.to_str())
|
|
||||||
.unwrap_or("db");
|
|
||||||
let ext = db_file
|
|
||||||
.extension()
|
|
||||||
.and_then(|s| s.to_str())
|
|
||||||
.unwrap_or("db");
|
|
||||||
|
|
||||||
let now = chrono::Local::now();
|
let now = chrono::Local::now();
|
||||||
let backup_name = format!("{}_{}.{}", stem, now.format("%Y-%m-%d_%H%M%S"), ext);
|
let backup_name = format!("{}_{}.{}", stem, now.format("%Y-%m-%d_%H%M%S"), ext);
|
||||||
@ -82,11 +85,21 @@ fn backup_database(database_url: &str) {
|
|||||||
let wal_path = format!("{}-wal", db_path);
|
let wal_path = format!("{}-wal", db_path);
|
||||||
let shm_path = format!("{}-shm", db_path);
|
let shm_path = format!("{}-shm", db_path);
|
||||||
if std::path::Path::new(&wal_path).exists() {
|
if std::path::Path::new(&wal_path).exists() {
|
||||||
let wal_backup = backup_dir.join(format!("{}_{}.{}-wal", stem, now.format("%Y-%m-%d_%H%M%S"), ext));
|
let wal_backup = backup_dir.join(format!(
|
||||||
|
"{}_{}.{}-wal",
|
||||||
|
stem,
|
||||||
|
now.format("%Y-%m-%d_%H%M%S"),
|
||||||
|
ext
|
||||||
|
));
|
||||||
let _ = std::fs::copy(&wal_path, &wal_backup);
|
let _ = std::fs::copy(&wal_path, &wal_backup);
|
||||||
}
|
}
|
||||||
if std::path::Path::new(&shm_path).exists() {
|
if std::path::Path::new(&shm_path).exists() {
|
||||||
let shm_backup = backup_dir.join(format!("{}_{}.{}-shm", stem, now.format("%Y-%m-%d_%H%M%S"), ext));
|
let shm_backup = backup_dir.join(format!(
|
||||||
|
"{}_{}.{}-shm",
|
||||||
|
stem,
|
||||||
|
now.format("%Y-%m-%d_%H%M%S"),
|
||||||
|
ext
|
||||||
|
));
|
||||||
let _ = std::fs::copy(&shm_path, &shm_backup);
|
let _ = std::fs::copy(&shm_path, &shm_backup);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -119,20 +132,35 @@ fn prune_old_backups(backup_dir: &std::path::Path, stem: &str, keep: usize) {
|
|||||||
let to_remove = backups.len() - keep;
|
let to_remove = backups.len() - keep;
|
||||||
for entry in backups.into_iter().take(to_remove) {
|
for entry in backups.into_iter().take(to_remove) {
|
||||||
let path = entry.path();
|
let path = entry.path();
|
||||||
let name = path.file_name().unwrap_or_default().to_string_lossy().to_string();
|
let name = path
|
||||||
|
.file_name()
|
||||||
|
.unwrap_or_default()
|
||||||
|
.to_string_lossy()
|
||||||
|
.to_string();
|
||||||
if let Err(e) = std::fs::remove_file(&path) {
|
if let Err(e) = std::fs::remove_file(&path) {
|
||||||
tracing::warn!("Failed to remove old backup {}: {}", name, e);
|
tracing::warn!("Failed to remove old backup {}: {}", name, e);
|
||||||
} else {
|
} else {
|
||||||
tracing::debug!("Pruned old backup: {}", name);
|
tracing::debug!("Pruned old backup: {}", name);
|
||||||
// Also remove associated WAL/SHM backups
|
// Also remove associated WAL/SHM backups
|
||||||
let wal = path.with_extension(format!("{}-wal", path.extension().unwrap_or_default().to_string_lossy()));
|
let wal = path.with_extension(format!(
|
||||||
let shm = path.with_extension(format!("{}-shm", path.extension().unwrap_or_default().to_string_lossy()));
|
"{}-wal",
|
||||||
|
path.extension().unwrap_or_default().to_string_lossy()
|
||||||
|
));
|
||||||
|
let shm = path.with_extension(format!(
|
||||||
|
"{}-shm",
|
||||||
|
path.extension().unwrap_or_default().to_string_lossy()
|
||||||
|
));
|
||||||
let _ = std::fs::remove_file(&wal);
|
let _ = std::fs::remove_file(&wal);
|
||||||
let _ = std::fs::remove_file(&shm);
|
let _ = std::fs::remove_file(&shm);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Shared state injected into every handler.
|
||||||
|
///
|
||||||
|
/// Axum stores this behind an `Arc`, so handlers can cheaply clone the pointer
|
||||||
|
/// while all requests still talk to the same database pool, API keys, and
|
||||||
|
/// broadcast channel.
|
||||||
pub struct AppState {
|
pub struct AppState {
|
||||||
pub db: sqlx::SqlitePool,
|
pub db: sqlx::SqlitePool,
|
||||||
pub jwt_secret: String,
|
pub jwt_secret: String,
|
||||||
@ -154,11 +182,15 @@ async fn main() {
|
|||||||
.with(tracing_subscriber::fmt::layer())
|
.with(tracing_subscriber::fmt::layer())
|
||||||
.init();
|
.init();
|
||||||
|
|
||||||
let database_url = std::env::var("DATABASE_URL").unwrap_or_else(|_| "sqlite:chat.db?mode=rwc".into());
|
// Load the runtime configuration needed to start the server.
|
||||||
|
let database_url =
|
||||||
|
std::env::var("DATABASE_URL").unwrap_or_else(|_| "sqlite:chat.db?mode=rwc".into());
|
||||||
let jwt_secret = std::env::var("JWT_SECRET").unwrap_or_else(|_| "dev-secret-change-me".into());
|
let jwt_secret = std::env::var("JWT_SECRET").unwrap_or_else(|_| "dev-secret-change-me".into());
|
||||||
let openrouter_key = std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY must be set");
|
let openrouter_key =
|
||||||
let search_provider = SearchProvider::from_env(std::env::var("SEARCH_PROVIDER").ok().as_deref())
|
std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY must be set");
|
||||||
.unwrap_or_else(|e| panic!("{}", e));
|
let search_provider =
|
||||||
|
SearchProvider::from_env(std::env::var("SEARCH_PROVIDER").ok().as_deref())
|
||||||
|
.unwrap_or_else(|e| panic!("{}", e));
|
||||||
let tavily_api_key = std::env::var("TAVILY_API_KEY").ok();
|
let tavily_api_key = std::env::var("TAVILY_API_KEY").ok();
|
||||||
let brave_api_key = std::env::var("BRAVE_API_KEY").ok();
|
let brave_api_key = std::env::var("BRAVE_API_KEY").ok();
|
||||||
|
|
||||||
@ -181,7 +213,8 @@ async fn main() {
|
|||||||
.await
|
.await
|
||||||
.expect("Failed to connect to database");
|
.expect("Failed to connect to database");
|
||||||
|
|
||||||
// Run migrations
|
// Run migrations in order. Each one is written so startup can safely try it
|
||||||
|
// again and skip work that already happened in an earlier run.
|
||||||
let migration_sql = include_str!("../migrations/001_init.sql");
|
let migration_sql = include_str!("../migrations/001_init.sql");
|
||||||
sqlx::raw_sql(migration_sql)
|
sqlx::raw_sql(migration_sql)
|
||||||
.execute(&db)
|
.execute(&db)
|
||||||
@ -282,6 +315,8 @@ async fn main() {
|
|||||||
|
|
||||||
tracing::info!("Database initialized");
|
tracing::info!("Database initialized");
|
||||||
|
|
||||||
|
// WebSocket tasks subscribe to this channel to receive room events without
|
||||||
|
// polling the database.
|
||||||
let (tx, _rx) = broadcast::channel::<models::BroadcastEvent>(4096);
|
let (tx, _rx) = broadcast::channel::<models::BroadcastEvent>(4096);
|
||||||
|
|
||||||
let state = Arc::new(AppState {
|
let state = Arc::new(AppState {
|
||||||
@ -302,32 +337,61 @@ async fn main() {
|
|||||||
// Serve static files from client dist in production
|
// Serve static files from client dist in production
|
||||||
let static_dir = std::env::var("STATIC_DIR").unwrap_or_else(|_| "../client/dist".into());
|
let static_dir = std::env::var("STATIC_DIR").unwrap_or_else(|_| "../client/dist".into());
|
||||||
|
|
||||||
|
// Keep API routes separate from the static-file fallback so `/api/*` and
|
||||||
|
// `/ws` requests never get mistaken for SPA routes.
|
||||||
let api_routes = Router::new()
|
let api_routes = Router::new()
|
||||||
// Auth routes
|
// Auth routes
|
||||||
.route("/api/auth/register", post(handlers::auth::register))
|
.route("/api/auth/register", post(handlers::auth::register))
|
||||||
.route("/api/auth/login", post(handlers::auth::login))
|
.route("/api/auth/login", post(handlers::auth::login))
|
||||||
.route("/api/auth/me", get(handlers::auth::me))
|
.route("/api/auth/me", get(handlers::auth::me))
|
||||||
// Nostr auth routes
|
// Nostr auth routes
|
||||||
.route("/api/auth/nostr/challenge", get(handlers::nostr_auth::challenge))
|
.route(
|
||||||
|
"/api/auth/nostr/challenge",
|
||||||
|
get(handlers::nostr_auth::challenge),
|
||||||
|
)
|
||||||
.route("/api/auth/nostr/verify", post(handlers::nostr_auth::verify))
|
.route("/api/auth/nostr/verify", post(handlers::nostr_auth::verify))
|
||||||
// Profile routes
|
// Profile routes
|
||||||
.route("/api/auth/profile", put(handlers::profile::update_profile))
|
.route("/api/auth/profile", put(handlers::profile::update_profile))
|
||||||
.route("/api/auth/avatar", post(handlers::profile::upload_avatar).delete(handlers::profile::delete_avatar))
|
.route(
|
||||||
|
"/api/auth/avatar",
|
||||||
|
post(handlers::profile::upload_avatar).delete(handlers::profile::delete_avatar),
|
||||||
|
)
|
||||||
// Room routes
|
// Room routes
|
||||||
.route("/api/rooms", get(handlers::rooms::list_rooms).post(handlers::rooms::create_room))
|
.route(
|
||||||
.route("/api/rooms/:room_id", get(handlers::rooms::get_room).delete(handlers::rooms::delete_room))
|
"/api/rooms",
|
||||||
.route("/api/rooms/:room_id/messages", get(handlers::rooms::get_messages))
|
get(handlers::rooms::list_rooms).post(handlers::rooms::create_room),
|
||||||
|
)
|
||||||
|
.route(
|
||||||
|
"/api/rooms/:room_id",
|
||||||
|
get(handlers::rooms::get_room).delete(handlers::rooms::delete_room),
|
||||||
|
)
|
||||||
|
.route(
|
||||||
|
"/api/rooms/:room_id/messages",
|
||||||
|
get(handlers::rooms::get_messages),
|
||||||
|
)
|
||||||
.route("/api/rooms/:room_id/join", post(handlers::rooms::join_room))
|
.route("/api/rooms/:room_id/join", post(handlers::rooms::join_room))
|
||||||
.route("/api/rooms/:room_id/clear", post(handlers::rooms::clear_room))
|
.route(
|
||||||
.route("/api/messages/hash/:hash", get(handlers::rooms::resolve_message_hash))
|
"/api/rooms/:room_id/clear",
|
||||||
|
post(handlers::rooms::clear_room),
|
||||||
|
)
|
||||||
|
.route(
|
||||||
|
"/api/messages/hash/:hash",
|
||||||
|
get(handlers::rooms::resolve_message_hash),
|
||||||
|
)
|
||||||
// Upload (chat images)
|
// Upload (chat images)
|
||||||
.route("/api/upload", post(handlers::upload::upload_chat_image))
|
.route("/api/upload", post(handlers::upload::upload_chat_image))
|
||||||
// Models
|
// Models
|
||||||
.route("/api/models", get(handlers::models::list_models))
|
.route("/api/models", get(handlers::models::list_models))
|
||||||
// Invite routes
|
// Invite routes
|
||||||
.route("/api/invites", post(handlers::invites::create_invite))
|
.route("/api/invites", post(handlers::invites::create_invite))
|
||||||
.route("/api/invites/:token/accept", post(handlers::invites::accept_invite))
|
.route(
|
||||||
.route("/api/invites/nostr", post(handlers::invites::invite_by_nostr))
|
"/api/invites/:token/accept",
|
||||||
|
post(handlers::invites::accept_invite),
|
||||||
|
)
|
||||||
|
.route(
|
||||||
|
"/api/invites/nostr",
|
||||||
|
post(handlers::invites::invite_by_nostr),
|
||||||
|
)
|
||||||
// Uploaded files (avatars)
|
// Uploaded files (avatars)
|
||||||
.nest_service("/uploads", ServeDir::new("uploads"))
|
.nest_service("/uploads", ServeDir::new("uploads"))
|
||||||
// WebSocket
|
// WebSocket
|
||||||
|
|||||||
@ -1,14 +1,11 @@
|
|||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use axum::{
|
use axum::{extract::FromRequestParts, http::request::Parts};
|
||||||
extract::FromRequestParts,
|
|
||||||
http::request::Parts,
|
|
||||||
};
|
|
||||||
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
|
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use crate::{models::Claims, AppState};
|
use crate::{models::Claims, AppState};
|
||||||
|
|
||||||
/// Extract authenticated user from JWT in Authorization header
|
/// Authenticated user information extracted from the bearer token.
|
||||||
pub struct AuthUser {
|
pub struct AuthUser {
|
||||||
pub user_id: String,
|
pub user_id: String,
|
||||||
pub email: String,
|
pub email: String,
|
||||||
@ -19,7 +16,15 @@ pub struct AuthUser {
|
|||||||
impl FromRequestParts<Arc<AppState>> for AuthUser {
|
impl FromRequestParts<Arc<AppState>> for AuthUser {
|
||||||
type Rejection = axum::http::StatusCode;
|
type Rejection = axum::http::StatusCode;
|
||||||
|
|
||||||
async fn from_request_parts(parts: &mut Parts, state: &Arc<AppState>) -> Result<Self, Self::Rejection> {
|
/// Read the `Authorization: Bearer <token>` header and decode the JWT.
|
||||||
|
///
|
||||||
|
/// Axum runs this automatically for any handler parameter of type
|
||||||
|
/// `AuthUser`, which keeps individual handlers free from repeated token
|
||||||
|
/// parsing logic.
|
||||||
|
async fn from_request_parts(
|
||||||
|
parts: &mut Parts,
|
||||||
|
state: &Arc<AppState>,
|
||||||
|
) -> Result<Self, Self::Rejection> {
|
||||||
let auth_header = parts
|
let auth_header = parts
|
||||||
.headers
|
.headers
|
||||||
.get("Authorization")
|
.get("Authorization")
|
||||||
@ -41,7 +46,16 @@ impl FromRequestParts<Arc<AppState>> for AuthUser {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn create_token(user_id: &str, email: &str, display_name: &str, secret: &str) -> Result<String, jsonwebtoken::errors::Error> {
|
/// Create a signed JWT for a logged-in user.
|
||||||
|
///
|
||||||
|
/// The token expires after seven days and carries the small amount of identity
|
||||||
|
/// data the server wants available on every request.
|
||||||
|
pub fn create_token(
|
||||||
|
user_id: &str,
|
||||||
|
email: &str,
|
||||||
|
display_name: &str,
|
||||||
|
secret: &str,
|
||||||
|
) -> Result<String, jsonwebtoken::errors::Error> {
|
||||||
let expiration = chrono::Utc::now()
|
let expiration = chrono::Utc::now()
|
||||||
.checked_add_signed(chrono::Duration::days(7))
|
.checked_add_signed(chrono::Duration::days(7))
|
||||||
.unwrap()
|
.unwrap()
|
||||||
@ -61,6 +75,7 @@ pub fn create_token(user_id: &str, email: &str, display_name: &str, secret: &str
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Decode and validate a previously issued JWT.
|
||||||
pub fn decode_token(token: &str, secret: &str) -> Result<Claims, jsonwebtoken::errors::Error> {
|
pub fn decode_token(token: &str, secret: &str) -> Result<Claims, jsonwebtoken::errors::Error> {
|
||||||
let token_data = decode::<Claims>(
|
let token_data = decode::<Claims>(
|
||||||
token,
|
token,
|
||||||
|
|||||||
@ -1 +1,3 @@
|
|||||||
|
//! Reusable request-processing layers shared across handlers.
|
||||||
|
|
||||||
pub mod auth;
|
pub mod auth;
|
||||||
|
|||||||
@ -1,7 +1,14 @@
|
|||||||
|
//! Core data structures shared across the server.
|
||||||
|
//!
|
||||||
|
//! This file intentionally mixes database row types, HTTP payloads, WebSocket
|
||||||
|
//! payloads, and a few helper functions so the rest of the codebase can import
|
||||||
|
//! common shapes from one place.
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
// ── Database models ──
|
// ── Database models ──
|
||||||
|
|
||||||
|
/// Row from the `users` table.
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
|
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
|
||||||
pub struct User {
|
pub struct User {
|
||||||
pub id: String,
|
pub id: String,
|
||||||
@ -11,6 +18,7 @@ pub struct User {
|
|||||||
pub created_at: String,
|
pub created_at: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Row from the `rooms` table.
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
|
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
|
||||||
pub struct Room {
|
pub struct Room {
|
||||||
pub id: String,
|
pub id: String,
|
||||||
@ -24,6 +32,7 @@ pub struct Room {
|
|||||||
pub deleted_at: Option<String>,
|
pub deleted_at: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Row from the `messages` table.
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
|
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
|
||||||
pub struct Message {
|
pub struct Message {
|
||||||
pub id: String,
|
pub id: String,
|
||||||
@ -38,6 +47,7 @@ pub struct Message {
|
|||||||
pub hash: Option<String>,
|
pub hash: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Row from the `invites` table.
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
|
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
|
||||||
pub struct Invite {
|
pub struct Invite {
|
||||||
pub id: String,
|
pub id: String,
|
||||||
@ -51,6 +61,7 @@ pub struct Invite {
|
|||||||
|
|
||||||
// ── API request/response types ──
|
// ── API request/response types ──
|
||||||
|
|
||||||
|
/// JSON body expected by the registration endpoint.
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
pub struct RegisterRequest {
|
pub struct RegisterRequest {
|
||||||
pub email: String,
|
pub email: String,
|
||||||
@ -58,18 +69,21 @@ pub struct RegisterRequest {
|
|||||||
pub display_name: String,
|
pub display_name: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// JSON body expected by the login endpoint.
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
pub struct LoginRequest {
|
pub struct LoginRequest {
|
||||||
pub email: String,
|
pub email: String,
|
||||||
pub password: String,
|
pub password: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Standard auth response returned after login, registration, or profile update.
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
pub struct AuthResponse {
|
pub struct AuthResponse {
|
||||||
pub token: String,
|
pub token: String,
|
||||||
pub user: UserPublic,
|
pub user: UserPublic,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Public user data safe to return to any authenticated client.
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct UserPublic {
|
pub struct UserPublic {
|
||||||
pub id: String,
|
pub id: String,
|
||||||
@ -81,6 +95,7 @@ pub struct UserPublic {
|
|||||||
pub nostr_pubkey: Option<String>,
|
pub nostr_pubkey: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// JSON body used when a user creates a new chat room.
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
pub struct CreateRoomRequest {
|
pub struct CreateRoomRequest {
|
||||||
pub name: String,
|
pub name: String,
|
||||||
@ -93,10 +108,10 @@ pub struct CreateRoomRequest {
|
|||||||
pub ai_name: String,
|
pub ai_name: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Pick a friendly default AI display name when the creator does not specify one.
|
||||||
fn default_ai_name() -> String {
|
fn default_ai_name() -> String {
|
||||||
let names = [
|
let names = [
|
||||||
"Nova", "Atlas", "Sage", "Echo", "Pixel",
|
"Nova", "Atlas", "Sage", "Echo", "Pixel", "Cosmo", "Ember", "Flux", "Lyra", "Onyx",
|
||||||
"Cosmo", "Ember", "Flux", "Lyra", "Onyx",
|
|
||||||
];
|
];
|
||||||
let idx = std::time::SystemTime::now()
|
let idx = std::time::SystemTime::now()
|
||||||
.duration_since(std::time::UNIX_EPOCH)
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
@ -105,10 +120,12 @@ fn default_ai_name() -> String {
|
|||||||
names[idx].to_string()
|
names[idx].to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Default prompt that defines the AI assistant's behavior inside a room.
|
||||||
fn default_system_prompt() -> String {
|
fn default_system_prompt() -> String {
|
||||||
"You are a helpful AI assistant participating in a group chat. Be conversational, helpful, and concise. You can see messages from all participants. When mentioned with @ai, respond helpfully.\n\nYou have access to tools:\n- **web_search**: Search the web for current information. Use this when asked about recent events, news, facts you're unsure about, or anything that needs up-to-date information.\n- **web_fetch**: Fetch and read the content of a web page. Use this when a user shares a URL and wants you to read/summarize it, or when you need more details from a search result.\n\nUse tools proactively when they would help answer the question better. You don't need to ask permission to use them.".to_string()
|
"You are a helpful AI assistant participating in a group chat. Be conversational, helpful, and concise. You can see messages from all participants. When mentioned with @ai, respond helpfully.\n\nYou have access to tools:\n- **web_search**: Search the web for current information. Use this when asked about recent events, news, facts you're unsure about, or anything that needs up-to-date information.\n- **web_fetch**: Fetch and read the content of a web page. Use this when a user shares a URL and wants you to read/summarize it, or when you need more details from a search result.\n\nUse tools proactively when they would help answer the question better. You don't need to ask permission to use them.".to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Full room payload returned to the client, including current members.
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
pub struct RoomResponse {
|
pub struct RoomResponse {
|
||||||
pub id: String,
|
pub id: String,
|
||||||
@ -122,6 +139,7 @@ pub struct RoomResponse {
|
|||||||
pub members: Vec<UserPublic>,
|
pub members: Vec<UserPublic>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// JSON body for an email-based room invite.
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
pub struct CreateInviteRequest {
|
pub struct CreateInviteRequest {
|
||||||
pub room_id: String,
|
pub room_id: String,
|
||||||
@ -130,6 +148,7 @@ pub struct CreateInviteRequest {
|
|||||||
|
|
||||||
// ── WebSocket event types ──
|
// ── WebSocket event types ──
|
||||||
|
|
||||||
|
/// Messages the browser can send over the WebSocket connection.
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
#[serde(tag = "type")]
|
#[serde(tag = "type")]
|
||||||
pub enum WsClientMessage {
|
pub enum WsClientMessage {
|
||||||
@ -148,17 +167,14 @@ pub enum WsClientMessage {
|
|||||||
Typing { room_id: String },
|
Typing { room_id: String },
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Messages the server can push to browsers over the WebSocket connection.
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
#[serde(tag = "type")]
|
#[serde(tag = "type")]
|
||||||
pub enum WsServerMessage {
|
pub enum WsServerMessage {
|
||||||
#[serde(rename = "new_message")]
|
#[serde(rename = "new_message")]
|
||||||
NewMessage {
|
NewMessage { message: MessagePayload },
|
||||||
message: MessagePayload,
|
|
||||||
},
|
|
||||||
#[serde(rename = "ai_typing")]
|
#[serde(rename = "ai_typing")]
|
||||||
AiTyping {
|
AiTyping { room_id: String },
|
||||||
room_id: String,
|
|
||||||
},
|
|
||||||
#[serde(rename = "user_typing")]
|
#[serde(rename = "user_typing")]
|
||||||
UserTyping {
|
UserTyping {
|
||||||
room_id: String,
|
room_id: String,
|
||||||
@ -166,21 +182,13 @@ pub enum WsServerMessage {
|
|||||||
display_name: String,
|
display_name: String,
|
||||||
},
|
},
|
||||||
#[serde(rename = "error")]
|
#[serde(rename = "error")]
|
||||||
Error {
|
Error { message: String },
|
||||||
message: String,
|
|
||||||
},
|
|
||||||
#[serde(rename = "joined")]
|
#[serde(rename = "joined")]
|
||||||
Joined {
|
Joined { room_id: String },
|
||||||
room_id: String,
|
|
||||||
},
|
|
||||||
#[serde(rename = "room_deleted")]
|
#[serde(rename = "room_deleted")]
|
||||||
RoomDeleted {
|
RoomDeleted { room_id: String },
|
||||||
room_id: String,
|
|
||||||
},
|
|
||||||
#[serde(rename = "room_cleared")]
|
#[serde(rename = "room_cleared")]
|
||||||
RoomCleared {
|
RoomCleared { room_id: String },
|
||||||
room_id: String,
|
|
||||||
},
|
|
||||||
#[serde(rename = "ai_tool_usage")]
|
#[serde(rename = "ai_tool_usage")]
|
||||||
AiToolUsage {
|
AiToolUsage {
|
||||||
room_id: String,
|
room_id: String,
|
||||||
@ -194,12 +202,10 @@ pub enum WsServerMessage {
|
|||||||
delta: String,
|
delta: String,
|
||||||
},
|
},
|
||||||
#[serde(rename = "ai_stream_end")]
|
#[serde(rename = "ai_stream_end")]
|
||||||
AiStreamEnd {
|
AiStreamEnd { room_id: String, message_id: String },
|
||||||
room_id: String,
|
|
||||||
message_id: String,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Message shape sent to clients for history loading and live updates.
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct MessagePayload {
|
pub struct MessagePayload {
|
||||||
pub id: String,
|
pub id: String,
|
||||||
@ -224,7 +230,7 @@ pub struct MessagePayload {
|
|||||||
|
|
||||||
/// Compute Gravatar-compatible MD5 hash from an email address.
|
/// Compute Gravatar-compatible MD5 hash from an email address.
|
||||||
pub fn gravatar_hash(email: &str) -> String {
|
pub fn gravatar_hash(email: &str) -> String {
|
||||||
use md5::{Md5, Digest};
|
use md5::{Digest, Md5};
|
||||||
let normalized = email.trim().to_lowercase();
|
let normalized = email.trim().to_lowercase();
|
||||||
let result = Md5::digest(normalized.as_bytes());
|
let result = Md5::digest(normalized.as_bytes());
|
||||||
format!("{:x}", result)
|
format!("{:x}", result)
|
||||||
@ -232,13 +238,14 @@ pub fn gravatar_hash(email: &str) -> String {
|
|||||||
|
|
||||||
/// Compute SHA-256 integrity hash from created_at timestamp + message content.
|
/// Compute SHA-256 integrity hash from created_at timestamp + message content.
|
||||||
pub fn message_hash(created_at: &str, content: &str) -> String {
|
pub fn message_hash(created_at: &str, content: &str) -> String {
|
||||||
use sha2::{Sha256, Digest};
|
use sha2::{Digest, Sha256};
|
||||||
let mut hasher = Sha256::new();
|
let mut hasher = Sha256::new();
|
||||||
hasher.update(created_at.as_bytes());
|
hasher.update(created_at.as_bytes());
|
||||||
hasher.update(content.as_bytes());
|
hasher.update(content.as_bytes());
|
||||||
format!("{:x}", hasher.finalize())
|
format!("{:x}", hasher.finalize())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Usage and tool metadata captured for AI-generated messages.
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct AiMeta {
|
pub struct AiMeta {
|
||||||
pub model: String,
|
pub model: String,
|
||||||
@ -250,6 +257,7 @@ pub struct AiMeta {
|
|||||||
pub tool_results: Option<Vec<ToolResult>>,
|
pub tool_results: Option<Vec<ToolResult>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// One tool invocation performed while generating an AI answer.
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct ToolResult {
|
pub struct ToolResult {
|
||||||
pub tool: String,
|
pub tool: String,
|
||||||
@ -259,6 +267,7 @@ pub struct ToolResult {
|
|||||||
|
|
||||||
// ── Broadcast event (internal channel) ──
|
// ── Broadcast event (internal channel) ──
|
||||||
|
|
||||||
|
/// Internal event sent through a Tokio broadcast channel to WebSocket tasks.
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct BroadcastEvent {
|
pub struct BroadcastEvent {
|
||||||
pub room_id: String,
|
pub room_id: String,
|
||||||
@ -267,9 +276,10 @@ pub struct BroadcastEvent {
|
|||||||
|
|
||||||
// ── JWT Claims ──
|
// ── JWT Claims ──
|
||||||
|
|
||||||
|
/// Claims stored inside the server-issued JWT.
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
pub struct Claims {
|
pub struct Claims {
|
||||||
pub sub: String, // user_id
|
pub sub: String, // user_id
|
||||||
pub email: String,
|
pub email: String,
|
||||||
pub display_name: String,
|
pub display_name: String,
|
||||||
pub exp: usize,
|
pub exp: usize,
|
||||||
@ -277,6 +287,7 @@ pub struct Claims {
|
|||||||
|
|
||||||
// ── Pagination ──
|
// ── Pagination ──
|
||||||
|
|
||||||
|
/// Common pagination parameters for message history endpoints.
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
pub struct PaginationParams {
|
pub struct PaginationParams {
|
||||||
#[serde(default = "default_limit")]
|
#[serde(default = "default_limit")]
|
||||||
@ -288,7 +299,7 @@ fn default_limit() -> i64 {
|
|||||||
50
|
50
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns "" if the email is a sentinel nostr: value, otherwise returns it as-is.
|
/// Hide placeholder `nostr:*` emails from normal client responses.
|
||||||
pub fn public_email(email: &str) -> String {
|
pub fn public_email(email: &str) -> String {
|
||||||
if email.starts_with("nostr:") {
|
if email.starts_with("nostr:") {
|
||||||
String::new()
|
String::new()
|
||||||
@ -299,11 +310,13 @@ pub fn public_email(email: &str) -> String {
|
|||||||
|
|
||||||
// ── Nostr auth types ──
|
// ── Nostr auth types ──
|
||||||
|
|
||||||
|
/// Response returned by the Nostr challenge endpoint.
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
pub struct NostrChallengeResponse {
|
pub struct NostrChallengeResponse {
|
||||||
pub challenge: String,
|
pub challenge: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// JSON body sent by the client when proving Nostr ownership.
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
pub struct NostrVerifyRequest {
|
pub struct NostrVerifyRequest {
|
||||||
pub signed_event: String,
|
pub signed_event: String,
|
||||||
@ -312,6 +325,7 @@ pub struct NostrVerifyRequest {
|
|||||||
pub profile_picture: Option<String>,
|
pub profile_picture: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// JSON body for inviting an already-known Nostr user into a room.
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
pub struct NostrInviteRequest {
|
pub struct NostrInviteRequest {
|
||||||
pub room_id: String,
|
pub room_id: String,
|
||||||
|
|||||||
@ -4,6 +4,7 @@ use crate::services::search::SearchResult;
|
|||||||
|
|
||||||
const BRAVE_SEARCH_URL: &str = "https://api.search.brave.com/res/v1/web/search";
|
const BRAVE_SEARCH_URL: &str = "https://api.search.brave.com/res/v1/web/search";
|
||||||
|
|
||||||
|
/// Partial Brave API response containing only the fields this app needs.
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
struct BraveResponse {
|
struct BraveResponse {
|
||||||
web: Option<BraveWebResults>,
|
web: Option<BraveWebResults>,
|
||||||
@ -27,11 +28,7 @@ struct BraveResult {
|
|||||||
|
|
||||||
/// Search the web using the Brave Search API.
|
/// Search the web using the Brave Search API.
|
||||||
/// Returns a list of simplified search results.
|
/// Returns a list of simplified search results.
|
||||||
pub async fn search(
|
pub async fn search(query: &str, api_key: &str, count: u8) -> Result<Vec<SearchResult>, String> {
|
||||||
query: &str,
|
|
||||||
api_key: &str,
|
|
||||||
count: u8,
|
|
||||||
) -> Result<Vec<SearchResult>, String> {
|
|
||||||
let count = count.clamp(1, 10);
|
let count = count.clamp(1, 10);
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
|
|||||||
@ -19,9 +19,29 @@ const STRIP_TAGS: &[&str] = &[
|
|||||||
|
|
||||||
/// Block-level tags that should produce newlines in text output.
|
/// Block-level tags that should produce newlines in text output.
|
||||||
const BLOCK_TAGS: &[&str] = &[
|
const BLOCK_TAGS: &[&str] = &[
|
||||||
"p", "div", "h1", "h2", "h3", "h4", "h5", "h6", "li", "br", "tr",
|
"p",
|
||||||
"blockquote", "pre", "section", "article", "main", "header",
|
"div",
|
||||||
"dt", "dd", "figcaption", "table", "thead", "tbody",
|
"h1",
|
||||||
|
"h2",
|
||||||
|
"h3",
|
||||||
|
"h4",
|
||||||
|
"h5",
|
||||||
|
"h6",
|
||||||
|
"li",
|
||||||
|
"br",
|
||||||
|
"tr",
|
||||||
|
"blockquote",
|
||||||
|
"pre",
|
||||||
|
"section",
|
||||||
|
"article",
|
||||||
|
"main",
|
||||||
|
"header",
|
||||||
|
"dt",
|
||||||
|
"dd",
|
||||||
|
"figcaption",
|
||||||
|
"table",
|
||||||
|
"thead",
|
||||||
|
"tbody",
|
||||||
];
|
];
|
||||||
|
|
||||||
/// Fetch a URL and extract its text content.
|
/// Fetch a URL and extract its text content.
|
||||||
|
|||||||
@ -1,3 +1,9 @@
|
|||||||
|
//! Integrations with external systems used by the chat server.
|
||||||
|
//!
|
||||||
|
//! These modules wrap search providers, web page fetching, and the OpenRouter
|
||||||
|
//! chat completion API so the rest of the application can call them with simple
|
||||||
|
//! Rust types.
|
||||||
|
|
||||||
pub mod brave;
|
pub mod brave;
|
||||||
pub mod fetch;
|
pub mod fetch;
|
||||||
pub mod openrouter;
|
pub mod openrouter;
|
||||||
|
|||||||
@ -235,7 +235,9 @@ pub async fn chat_completion_stream(
|
|||||||
{
|
{
|
||||||
Ok(r) => r,
|
Ok(r) => r,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
let _ = tx.send(StreamEvent::Error(format!("Request failed: {}", e))).await;
|
let _ = tx
|
||||||
|
.send(StreamEvent::Error(format!("Request failed: {}", e)))
|
||||||
|
.await;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -243,7 +245,12 @@ pub async fn chat_completion_stream(
|
|||||||
if !response.status().is_success() {
|
if !response.status().is_success() {
|
||||||
let status = response.status();
|
let status = response.status();
|
||||||
let body = response.text().await.unwrap_or_default();
|
let body = response.text().await.unwrap_or_default();
|
||||||
let _ = tx.send(StreamEvent::Error(format!("OpenRouter error {}: {}", status, body))).await;
|
let _ = tx
|
||||||
|
.send(StreamEvent::Error(format!(
|
||||||
|
"OpenRouter error {}: {}",
|
||||||
|
status, body
|
||||||
|
)))
|
||||||
|
.await;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -264,7 +271,9 @@ pub async fn chat_completion_stream(
|
|||||||
let bytes = match chunk_result {
|
let bytes = match chunk_result {
|
||||||
Ok(b) => b,
|
Ok(b) => b,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
let _ = tx.send(StreamEvent::Error(format!("Stream error: {}", e))).await;
|
let _ = tx
|
||||||
|
.send(StreamEvent::Error(format!("Stream error: {}", e)))
|
||||||
|
.await;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -338,7 +347,10 @@ pub async fn chat_completion_stream(
|
|||||||
tool_call_accum[idx].function.name.push_str(name);
|
tool_call_accum[idx].function.name.push_str(name);
|
||||||
}
|
}
|
||||||
if let Some(args) = &func.arguments {
|
if let Some(args) = &func.arguments {
|
||||||
tool_call_accum[idx].function.arguments.push_str(args);
|
tool_call_accum[idx]
|
||||||
|
.function
|
||||||
|
.arguments
|
||||||
|
.push_str(args);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -373,7 +385,11 @@ pub async fn chat_completion_stream(
|
|||||||
// AI requested tool calls
|
// AI requested tool calls
|
||||||
let assistant_msg = ChatMessage {
|
let assistant_msg = ChatMessage {
|
||||||
role: "assistant".into(),
|
role: "assistant".into(),
|
||||||
content: if full_content.is_empty() { None } else { Some(Content::Text(full_content)) },
|
content: if full_content.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(Content::Text(full_content))
|
||||||
|
},
|
||||||
tool_calls: Some(tool_call_accum),
|
tool_calls: Some(tool_call_accum),
|
||||||
tool_call_id: None,
|
tool_call_id: None,
|
||||||
};
|
};
|
||||||
@ -420,7 +436,9 @@ pub fn build_chat_history(
|
|||||||
Content::Parts(vec![
|
Content::Parts(vec![
|
||||||
ContentPart::Text { text },
|
ContentPart::Text { text },
|
||||||
ContentPart::ImageUrl {
|
ContentPart::ImageUrl {
|
||||||
image_url: ImageUrlData { url: data_url.clone() },
|
image_url: ImageUrlData {
|
||||||
|
url: data_url.clone(),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
])
|
])
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@ -2,6 +2,7 @@ use serde::{Deserialize, Serialize};
|
|||||||
|
|
||||||
use super::{brave, tavily};
|
use super::{brave, tavily};
|
||||||
|
|
||||||
|
/// Which search backend the AI tool layer should call.
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
pub enum SearchProvider {
|
pub enum SearchProvider {
|
||||||
Tavily,
|
Tavily,
|
||||||
@ -9,8 +10,14 @@ pub enum SearchProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl SearchProvider {
|
impl SearchProvider {
|
||||||
|
/// Parse the `SEARCH_PROVIDER` environment variable into a supported variant.
|
||||||
pub fn from_env(value: Option<&str>) -> Result<Self, String> {
|
pub fn from_env(value: Option<&str>) -> Result<Self, String> {
|
||||||
match value.unwrap_or("tavily").trim().to_ascii_lowercase().as_str() {
|
match value
|
||||||
|
.unwrap_or("tavily")
|
||||||
|
.trim()
|
||||||
|
.to_ascii_lowercase()
|
||||||
|
.as_str()
|
||||||
|
{
|
||||||
"tavily" => Ok(Self::Tavily),
|
"tavily" => Ok(Self::Tavily),
|
||||||
"brave" => Ok(Self::Brave),
|
"brave" => Ok(Self::Brave),
|
||||||
other => Err(format!(
|
other => Err(format!(
|
||||||
@ -20,6 +27,7 @@ impl SearchProvider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Return the environment variable name required by the selected provider.
|
||||||
pub fn required_key_name(self) -> &'static str {
|
pub fn required_key_name(self) -> &'static str {
|
||||||
match self {
|
match self {
|
||||||
Self::Tavily => "TAVILY_API_KEY",
|
Self::Tavily => "TAVILY_API_KEY",
|
||||||
@ -28,6 +36,7 @@ impl SearchProvider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Normalized search result shape shared across providers.
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct SearchResult {
|
pub struct SearchResult {
|
||||||
pub title: String,
|
pub title: String,
|
||||||
@ -36,6 +45,7 @@ pub struct SearchResult {
|
|||||||
pub age: Option<String>,
|
pub age: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Dispatch a search request to whichever provider the server is configured to use.
|
||||||
pub async fn search(
|
pub async fn search(
|
||||||
provider: SearchProvider,
|
provider: SearchProvider,
|
||||||
query: &str,
|
query: &str,
|
||||||
@ -59,6 +69,7 @@ pub async fn search(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Turn search results into plain text the AI model can read as tool output.
|
||||||
pub fn format_results(results: &[SearchResult]) -> String {
|
pub fn format_results(results: &[SearchResult]) -> String {
|
||||||
if results.is_empty() {
|
if results.is_empty() {
|
||||||
return "No search results found.".to_string();
|
return "No search results found.".to_string();
|
||||||
|
|||||||
@ -23,11 +23,7 @@ struct TavilyResult {
|
|||||||
published_date: Option<String>,
|
published_date: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn search(
|
pub async fn search(query: &str, api_key: &str, count: u8) -> Result<Vec<SearchResult>, String> {
|
||||||
query: &str,
|
|
||||||
api_key: &str,
|
|
||||||
count: u8,
|
|
||||||
) -> Result<Vec<SearchResult>, String> {
|
|
||||||
let max_results = count.clamp(1, 10);
|
let max_results = count.clamp(1, 10);
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
|
|
||||||
@ -75,7 +71,10 @@ pub async fn search(
|
|||||||
if first.description.is_empty() {
|
if first.description.is_empty() {
|
||||||
first.description = format!("AI summary: {}", answer);
|
first.description = format!("AI summary: {}", answer);
|
||||||
} else {
|
} else {
|
||||||
first.description = format!("AI summary: {}\nSource excerpt: {}", answer, first.description);
|
first.description = format!(
|
||||||
|
"AI summary: {}\nSource excerpt: {}",
|
||||||
|
answer, first.description
|
||||||
|
);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
results.push(SearchResult {
|
results.push(SearchResult {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user