Document server code paths
This commit is contained in:
parent
927d106eae
commit
c37ff79514
@ -1,8 +1,8 @@
|
||||
use axum::{extract::State, http::StatusCode, Json};
|
||||
use argon2::{
|
||||
password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
|
||||
Argon2,
|
||||
};
|
||||
use axum::{extract::State, http::StatusCode, Json};
|
||||
use std::sync::Arc;
|
||||
use uuid::Uuid;
|
||||
|
||||
@ -12,6 +12,7 @@ use crate::{
|
||||
AppState,
|
||||
};
|
||||
|
||||
/// Create a new password-based account and immediately return a JWT.
|
||||
pub async fn register(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Json(body): Json<RegisterRequest>,
|
||||
@ -69,6 +70,7 @@ pub async fn register(
|
||||
}))
|
||||
}
|
||||
|
||||
/// Authenticate an existing password-based account and return a fresh JWT.
|
||||
pub async fn login(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Json(body): Json<LoginRequest>,
|
||||
@ -87,8 +89,8 @@ pub async fn login(
|
||||
|
||||
let (user_id, email, display_name, hash, avatar_url) = user;
|
||||
|
||||
let parsed_hash = PasswordHash::new(&hash)
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
let parsed_hash =
|
||||
PasswordHash::new(&hash).map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
Argon2::default()
|
||||
.verify_password(body.password.as_bytes(), &parsed_hash)
|
||||
@ -109,6 +111,7 @@ pub async fn login(
|
||||
}))
|
||||
}
|
||||
|
||||
/// Return the caller's current public profile information.
|
||||
pub async fn me(
|
||||
auth: AuthUser,
|
||||
State(state): State<Arc<AppState>>,
|
||||
|
||||
@ -13,6 +13,7 @@ use crate::{
|
||||
AppState,
|
||||
};
|
||||
|
||||
/// Response payload for a newly created invite link.
|
||||
#[derive(serde::Serialize)]
|
||||
pub struct InviteResponse {
|
||||
pub id: String,
|
||||
@ -20,6 +21,7 @@ pub struct InviteResponse {
|
||||
pub invite_url: String,
|
||||
}
|
||||
|
||||
/// Create a one-time invite token for a room member to share.
|
||||
pub async fn create_invite(
|
||||
State(state): State<Arc<AppState>>,
|
||||
auth: AuthUser,
|
||||
@ -46,15 +48,17 @@ pub async fn create_invite(
|
||||
.map(char::from)
|
||||
.collect();
|
||||
|
||||
sqlx::query("INSERT INTO invites (id, room_id, invited_by, email, token) VALUES (?, ?, ?, ?, ?)")
|
||||
.bind(&invite_id)
|
||||
.bind(&body.room_id)
|
||||
.bind(&auth.user_id)
|
||||
.bind(&body.email)
|
||||
.bind(&token)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
sqlx::query(
|
||||
"INSERT INTO invites (id, room_id, invited_by, email, token) VALUES (?, ?, ?, ?, ?)",
|
||||
)
|
||||
.bind(&invite_id)
|
||||
.bind(&body.room_id)
|
||||
.bind(&auth.user_id)
|
||||
.bind(&body.email)
|
||||
.bind(&token)
|
||||
.execute(&state.db)
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
Ok(Json(InviteResponse {
|
||||
id: invite_id,
|
||||
@ -63,11 +67,13 @@ pub async fn create_invite(
|
||||
}))
|
||||
}
|
||||
|
||||
/// Response payload returned after consuming an invite.
|
||||
#[derive(serde::Serialize)]
|
||||
pub struct AcceptInviteResponse {
|
||||
pub room_id: String,
|
||||
}
|
||||
|
||||
/// Consume an invite token and add the caller to the room.
|
||||
pub async fn accept_invite(
|
||||
State(state): State<Arc<AppState>>,
|
||||
auth: AuthUser,
|
||||
@ -89,13 +95,12 @@ pub async fn accept_invite(
|
||||
}
|
||||
|
||||
// Verify room is not deleted
|
||||
let room_active = sqlx::query_scalar::<_, String>(
|
||||
"SELECT id FROM rooms WHERE id = ? AND deleted_at IS NULL",
|
||||
)
|
||||
.bind(&room_id)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
let room_active =
|
||||
sqlx::query_scalar::<_, String>("SELECT id FROM rooms WHERE id = ? AND deleted_at IS NULL")
|
||||
.bind(&room_id)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
if room_active.is_none() {
|
||||
return Err((StatusCode::GONE, "This room has been deleted".into()));
|
||||
@ -119,6 +124,7 @@ pub async fn accept_invite(
|
||||
Ok(Json(AcceptInviteResponse { room_id }))
|
||||
}
|
||||
|
||||
/// Result of a Nostr-based room invite attempt.
|
||||
#[derive(serde::Serialize)]
|
||||
pub struct NostrInviteResponse {
|
||||
pub status: String,
|
||||
@ -126,6 +132,7 @@ pub struct NostrInviteResponse {
|
||||
pub display_name: Option<String>,
|
||||
}
|
||||
|
||||
/// Add a user to a room by their Nostr public key if they already have an account.
|
||||
pub async fn invite_by_nostr(
|
||||
State(state): State<Arc<AppState>>,
|
||||
auth: AuthUser,
|
||||
@ -138,8 +145,13 @@ pub async fn invite_by_nostr(
|
||||
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid npub format".to_string()))?
|
||||
} else {
|
||||
// Validate it's valid hex
|
||||
if body.nostr_pubkey.len() != 64 || !body.nostr_pubkey.chars().all(|c| c.is_ascii_hexdigit()) {
|
||||
return Err((StatusCode::BAD_REQUEST, "Invalid pubkey: must be 64-char hex or npub".to_string()));
|
||||
if body.nostr_pubkey.len() != 64
|
||||
|| !body.nostr_pubkey.chars().all(|c| c.is_ascii_hexdigit())
|
||||
{
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
"Invalid pubkey: must be 64-char hex or npub".to_string(),
|
||||
));
|
||||
}
|
||||
body.nostr_pubkey.clone()
|
||||
};
|
||||
|
||||
@ -1,3 +1,8 @@
|
||||
//! HTTP and WebSocket entry points for the server.
|
||||
//!
|
||||
//! Each submodule exposes route handlers that Axum wires into the router in
|
||||
//! `main.rs`.
|
||||
|
||||
pub mod auth;
|
||||
pub mod invites;
|
||||
pub mod models;
|
||||
|
||||
@ -1,19 +1,16 @@
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::StatusCode,
|
||||
Json,
|
||||
};
|
||||
use axum::{extract::State, http::StatusCode, Json};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::OnceCell;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::OnceCell;
|
||||
|
||||
use crate::AppState;
|
||||
|
||||
/// Cached model list with expiry.
|
||||
static MODEL_CACHE: OnceCell<Mutex<CachedModels>> = OnceCell::const_new();
|
||||
|
||||
/// Process-wide cache for the OpenRouter model catalog.
|
||||
struct CachedModels {
|
||||
models: Vec<ModelInfo>,
|
||||
fetched_at: Instant,
|
||||
@ -21,6 +18,7 @@ struct CachedModels {
|
||||
|
||||
const CACHE_TTL: Duration = Duration::from_secs(60 * 30); // 30 minutes
|
||||
|
||||
/// Model metadata exposed to the client for room creation and model selection.
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct ModelInfo {
|
||||
pub id: String,
|
||||
@ -56,6 +54,10 @@ struct OpenRouterArchitecture {
|
||||
input_modalities: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
/// Fetch the model catalog directly from OpenRouter.
|
||||
///
|
||||
/// The result is normalized into the smaller `ModelInfo` shape that the client
|
||||
/// UI needs.
|
||||
async fn fetch_models(api_key: &str) -> Result<Vec<ModelInfo>, String> {
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
@ -82,7 +84,8 @@ async fn fetch_models(api_key: &str) -> Result<Vec<ModelInfo>, String> {
|
||||
.into_iter()
|
||||
.map(|m| {
|
||||
let pricing = m.pricing.as_ref();
|
||||
let supports_vision = m.architecture
|
||||
let supports_vision = m
|
||||
.architecture
|
||||
.as_ref()
|
||||
.and_then(|a| a.input_modalities.as_ref())
|
||||
.map(|mods| mods.iter().any(|m| m == "image"))
|
||||
@ -102,6 +105,7 @@ async fn fetch_models(api_key: &str) -> Result<Vec<ModelInfo>, String> {
|
||||
Ok(models)
|
||||
}
|
||||
|
||||
/// Return the cached OpenRouter model list, refreshing it when the cache expires.
|
||||
pub async fn list_models(
|
||||
State(state): State<Arc<AppState>>,
|
||||
) -> Result<Json<Vec<ModelInfo>>, (StatusCode, String)> {
|
||||
|
||||
@ -10,6 +10,7 @@ use crate::{
|
||||
AppState,
|
||||
};
|
||||
|
||||
/// Claims embedded in the short-lived challenge token used during Nostr login.
|
||||
#[derive(Debug, serde::Serialize, serde::Deserialize)]
|
||||
struct ChallengeClaims {
|
||||
pub nonce: String,
|
||||
@ -28,10 +29,7 @@ pub async fn challenge(
|
||||
|
||||
let exp = (chrono::Utc::now().timestamp() + 120) as usize; // 2 minutes
|
||||
|
||||
let claims = ChallengeClaims {
|
||||
nonce,
|
||||
exp,
|
||||
};
|
||||
let claims = ChallengeClaims { nonce, exp };
|
||||
|
||||
let token = encode(
|
||||
&Header::default(),
|
||||
@ -45,6 +43,7 @@ pub async fn challenge(
|
||||
|
||||
/// Simple hex encoder (avoid adding the `hex` crate just for this)
|
||||
mod hex {
|
||||
/// Convert raw bytes into a lowercase hexadecimal string.
|
||||
pub fn encode(bytes: &[u8]) -> String {
|
||||
bytes.iter().map(|b| format!("{:02x}", b)).collect()
|
||||
}
|
||||
@ -61,17 +60,29 @@ pub async fn verify(
|
||||
&DecodingKey::from_secret(state.jwt_secret.as_bytes()),
|
||||
&Validation::default(),
|
||||
)
|
||||
.map_err(|_| (StatusCode::BAD_REQUEST, "Invalid or expired challenge".to_string()))?;
|
||||
.map_err(|_| {
|
||||
(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"Invalid or expired challenge".to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
let nonce = &challenge_data.claims.nonce;
|
||||
|
||||
// 2. Deserialize signed_event as nostr::Event
|
||||
let event: Event = serde_json::from_str(&body.signed_event)
|
||||
.map_err(|e| (StatusCode::BAD_REQUEST, format!("Invalid event JSON: {}", e)))?;
|
||||
let event: Event = serde_json::from_str(&body.signed_event).map_err(|e| {
|
||||
(
|
||||
StatusCode::BAD_REQUEST,
|
||||
format!("Invalid event JSON: {}", e),
|
||||
)
|
||||
})?;
|
||||
|
||||
// 3. Verify Schnorr signature
|
||||
if !event.verify_signature() {
|
||||
return Err((StatusCode::UNAUTHORIZED, "Invalid event signature".to_string()));
|
||||
return Err((
|
||||
StatusCode::UNAUTHORIZED,
|
||||
"Invalid event signature".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// 4. Verify event.content == nonce
|
||||
@ -83,7 +94,10 @@ pub async fn verify(
|
||||
let now = chrono::Utc::now().timestamp() as u64;
|
||||
let event_ts = event.created_at.as_secs();
|
||||
if now.abs_diff(event_ts) > 300 {
|
||||
return Err((StatusCode::BAD_REQUEST, "Event timestamp too far off".to_string()));
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
"Event timestamp too far off".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// 6. Extract pubkey hex
|
||||
|
||||
@ -11,6 +11,7 @@ use crate::{
|
||||
AppState,
|
||||
};
|
||||
|
||||
/// Request body for profile updates.
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
pub struct UpdateProfileRequest {
|
||||
pub display_name: Option<String>,
|
||||
@ -25,7 +26,10 @@ pub async fn update_profile(
|
||||
let display_name = body.display_name.unwrap_or(auth.display_name.clone());
|
||||
|
||||
if display_name.trim().is_empty() {
|
||||
return Err((StatusCode::BAD_REQUEST, "Display name cannot be empty".into()));
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
"Display name cannot be empty".into(),
|
||||
));
|
||||
}
|
||||
|
||||
sqlx::query("UPDATE users SET display_name = ? WHERE id = ?")
|
||||
@ -83,7 +87,12 @@ pub async fn upload_avatar(
|
||||
"image/jpeg" | "image/jpg" => "jpg",
|
||||
"image/gif" => "gif",
|
||||
"image/webp" => "webp",
|
||||
_ => return Err((StatusCode::BAD_REQUEST, "Only PNG, JPG, GIF, and WebP images are allowed".into())),
|
||||
_ => {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
"Only PNG, JPG, GIF, and WebP images are allowed".into(),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let data = field
|
||||
@ -130,8 +139,13 @@ pub async fn upload_avatar(
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
// Issue new token
|
||||
let token = create_token(&auth.user_id, &auth.email, &auth.display_name, &state.jwt_secret)
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
let token = create_token(
|
||||
&auth.user_id,
|
||||
&auth.email,
|
||||
&auth.display_name,
|
||||
&state.jwt_secret,
|
||||
)
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
Ok(Json(AuthResponse {
|
||||
token,
|
||||
@ -168,8 +182,13 @@ pub async fn delete_avatar(
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
let token = create_token(&auth.user_id, &auth.email, &auth.display_name, &state.jwt_secret)
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
let token = create_token(
|
||||
&auth.user_id,
|
||||
&auth.email,
|
||||
&auth.display_name,
|
||||
&state.jwt_secret,
|
||||
)
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
Ok(Json(AuthResponse {
|
||||
token,
|
||||
|
||||
@ -8,10 +8,13 @@ use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
middleware::auth::AuthUser,
|
||||
models::{self, CreateRoomRequest, MessagePayload, PaginationParams, Room, RoomResponse, UserPublic},
|
||||
models::{
|
||||
self, CreateRoomRequest, MessagePayload, PaginationParams, Room, RoomResponse, UserPublic,
|
||||
},
|
||||
AppState,
|
||||
};
|
||||
|
||||
/// Create a room, persist it, and add the creator as the first member.
|
||||
pub async fn create_room(
|
||||
State(state): State<Arc<AppState>>,
|
||||
auth: AuthUser,
|
||||
@ -60,6 +63,7 @@ pub async fn create_room(
|
||||
}))
|
||||
}
|
||||
|
||||
/// List all active rooms the caller belongs to, including current room members.
|
||||
pub async fn list_rooms(
|
||||
State(state): State<Arc<AppState>>,
|
||||
auth: AuthUser,
|
||||
@ -93,13 +97,15 @@ pub async fn list_rooms(
|
||||
created_at: room.created_at,
|
||||
members: members
|
||||
.into_iter()
|
||||
.map(|(id, email, display_name, avatar_url, nostr_pubkey)| UserPublic {
|
||||
id,
|
||||
email: models::public_email(&email),
|
||||
display_name,
|
||||
avatar_url,
|
||||
nostr_pubkey,
|
||||
})
|
||||
.map(
|
||||
|(id, email, display_name, avatar_url, nostr_pubkey)| UserPublic {
|
||||
id,
|
||||
email: models::public_email(&email),
|
||||
display_name,
|
||||
avatar_url,
|
||||
nostr_pubkey,
|
||||
},
|
||||
)
|
||||
.collect(),
|
||||
});
|
||||
}
|
||||
@ -107,6 +113,7 @@ pub async fn list_rooms(
|
||||
Ok(Json(result))
|
||||
}
|
||||
|
||||
/// Return details for a single room after verifying the caller is a member.
|
||||
pub async fn get_room(
|
||||
State(state): State<Arc<AppState>>,
|
||||
auth: AuthUser,
|
||||
@ -152,17 +159,20 @@ pub async fn get_room(
|
||||
created_at: room.created_at,
|
||||
members: members
|
||||
.into_iter()
|
||||
.map(|(id, email, display_name, avatar_url, nostr_pubkey)| UserPublic {
|
||||
id,
|
||||
email: models::public_email(&email),
|
||||
display_name,
|
||||
avatar_url,
|
||||
nostr_pubkey,
|
||||
})
|
||||
.map(
|
||||
|(id, email, display_name, avatar_url, nostr_pubkey)| UserPublic {
|
||||
id,
|
||||
email: models::public_email(&email),
|
||||
display_name,
|
||||
avatar_url,
|
||||
nostr_pubkey,
|
||||
},
|
||||
)
|
||||
.collect(),
|
||||
}))
|
||||
}
|
||||
|
||||
/// Return paginated message history for a room the caller can access.
|
||||
pub async fn get_messages(
|
||||
State(state): State<Arc<AppState>>,
|
||||
auth: AuthUser,
|
||||
@ -208,37 +218,56 @@ pub async fn get_messages(
|
||||
}
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
// The SQL query reads newest-first for efficient pagination, but clients
|
||||
// render chat oldest-to-newest, so reverse the rows before serializing.
|
||||
let payloads: Vec<MessagePayload> = rows
|
||||
.into_iter()
|
||||
.rev()
|
||||
.map(|(id, room_id, sender_id, sender_name, content, mentions, is_ai, created_at, ai_meta_str, image_url, email, avatar_url, hash)| {
|
||||
let ai_meta = ai_meta_str
|
||||
.as_deref()
|
||||
.and_then(|s| serde_json::from_str::<crate::models::AiMeta>(s).ok());
|
||||
let avatar_hash = email
|
||||
.map(|e| crate::models::gravatar_hash(&e))
|
||||
.unwrap_or_default();
|
||||
MessagePayload {
|
||||
.map(
|
||||
|(
|
||||
id,
|
||||
room_id,
|
||||
sender_id,
|
||||
sender_name,
|
||||
content,
|
||||
mentions: serde_json::from_str(&mentions).unwrap_or_default(),
|
||||
mentions,
|
||||
is_ai,
|
||||
created_at,
|
||||
ai_meta,
|
||||
avatar_hash,
|
||||
avatar_url,
|
||||
ai_meta_str,
|
||||
image_url,
|
||||
email,
|
||||
avatar_url,
|
||||
hash,
|
||||
}
|
||||
})
|
||||
)| {
|
||||
let ai_meta = ai_meta_str
|
||||
.as_deref()
|
||||
.and_then(|s| serde_json::from_str::<crate::models::AiMeta>(s).ok());
|
||||
let avatar_hash = email
|
||||
.map(|e| crate::models::gravatar_hash(&e))
|
||||
.unwrap_or_default();
|
||||
MessagePayload {
|
||||
id,
|
||||
room_id,
|
||||
sender_id,
|
||||
sender_name,
|
||||
content,
|
||||
mentions: serde_json::from_str(&mentions).unwrap_or_default(),
|
||||
is_ai,
|
||||
created_at,
|
||||
ai_meta,
|
||||
avatar_hash,
|
||||
avatar_url,
|
||||
image_url,
|
||||
hash,
|
||||
}
|
||||
},
|
||||
)
|
||||
.collect();
|
||||
|
||||
Ok(Json(payloads))
|
||||
}
|
||||
|
||||
/// Resolve a stable message hash into the room that contains it.
|
||||
pub async fn resolve_message_hash(
|
||||
State(state): State<Arc<AppState>>,
|
||||
auth: AuthUser,
|
||||
@ -258,22 +287,29 @@ pub async fn resolve_message_hash(
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
match row {
|
||||
Some((room_id,)) => Ok(Json(serde_json::json!({ "room_id": room_id, "hash": hash }))),
|
||||
None => Err((StatusCode::NOT_FOUND, "Message not found or no access".into())),
|
||||
Some((room_id,)) => Ok(Json(
|
||||
serde_json::json!({ "room_id": room_id, "hash": hash }),
|
||||
)),
|
||||
None => Err((
|
||||
StatusCode::NOT_FOUND,
|
||||
"Message not found or no access".into(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add the caller to a room directly when they already know its ID.
|
||||
pub async fn join_room(
|
||||
State(state): State<Arc<AppState>>,
|
||||
auth: AuthUser,
|
||||
Path(room_id): Path<String>,
|
||||
) -> Result<StatusCode, (StatusCode, String)> {
|
||||
// Check room exists
|
||||
let room_exists = sqlx::query_scalar::<_, String>("SELECT id FROM rooms WHERE id = ? AND deleted_at IS NULL")
|
||||
.bind(&room_id)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
let room_exists =
|
||||
sqlx::query_scalar::<_, String>("SELECT id FROM rooms WHERE id = ? AND deleted_at IS NULL")
|
||||
.bind(&room_id)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||||
|
||||
if room_exists.is_none() {
|
||||
return Err((StatusCode::NOT_FOUND, "Room not found".into()));
|
||||
@ -289,6 +325,7 @@ pub async fn join_room(
|
||||
Ok(StatusCode::OK)
|
||||
}
|
||||
|
||||
/// Soft-delete a room and broadcast the deletion event to connected members.
|
||||
pub async fn delete_room(
|
||||
State(state): State<Arc<AppState>>,
|
||||
auth: AuthUser,
|
||||
@ -303,7 +340,10 @@ pub async fn delete_room(
|
||||
.ok_or((StatusCode::NOT_FOUND, "Room not found".into()))?;
|
||||
|
||||
if room.created_by != auth.user_id {
|
||||
return Err((StatusCode::FORBIDDEN, "Only the room creator can delete this room".into()));
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
"Only the room creator can delete this room".into(),
|
||||
));
|
||||
}
|
||||
|
||||
// Soft-delete
|
||||
@ -324,6 +364,7 @@ pub async fn delete_room(
|
||||
Ok(StatusCode::OK)
|
||||
}
|
||||
|
||||
/// Permanently remove all messages from a room without deleting the room itself.
|
||||
pub async fn clear_room(
|
||||
State(state): State<Arc<AppState>>,
|
||||
auth: AuthUser,
|
||||
@ -338,7 +379,10 @@ pub async fn clear_room(
|
||||
.ok_or((StatusCode::NOT_FOUND, "Room not found".into()))?;
|
||||
|
||||
if room.created_by != auth.user_id {
|
||||
return Err((StatusCode::FORBIDDEN, "Only the room creator can clear messages".into()));
|
||||
return Err((
|
||||
StatusCode::FORBIDDEN,
|
||||
"Only the room creator can clear messages".into(),
|
||||
));
|
||||
}
|
||||
|
||||
// Hard-delete all messages
|
||||
|
||||
@ -1,18 +1,16 @@
|
||||
use axum::{
|
||||
extract::Multipart,
|
||||
http::StatusCode,
|
||||
Json,
|
||||
};
|
||||
use axum::{extract::Multipart, http::StatusCode, Json};
|
||||
use serde::Serialize;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::middleware::auth::AuthUser;
|
||||
|
||||
/// Response returned after a chat image upload succeeds.
|
||||
#[derive(Serialize)]
|
||||
pub struct UploadResponse {
|
||||
pub url: String,
|
||||
}
|
||||
|
||||
/// Accept a multipart chat image upload and store it under `uploads/chat-images`.
|
||||
pub async fn upload_chat_image(
|
||||
_auth: AuthUser,
|
||||
mut multipart: Multipart,
|
||||
|
||||
@ -1,3 +1,9 @@
|
||||
//! WebSocket workflow for live chat delivery and AI responses.
|
||||
//!
|
||||
//! This module does two jobs:
|
||||
//! - fan out database-backed room events to subscribed browser sockets
|
||||
//! - turn incoming user chat messages into stored messages and optional AI replies
|
||||
|
||||
use axum::{
|
||||
extract::{
|
||||
ws::{Message, WebSocket},
|
||||
@ -24,6 +30,7 @@ pub struct WsQuery {
|
||||
token: String,
|
||||
}
|
||||
|
||||
/// Upgrade an authenticated request into a WebSocket connection.
|
||||
pub async fn ws_handler(
|
||||
ws: WebSocketUpgrade,
|
||||
State(state): State<Arc<AppState>>,
|
||||
@ -37,10 +44,19 @@ pub async fn ws_handler(
|
||||
}
|
||||
};
|
||||
|
||||
ws.on_upgrade(move |socket| handle_socket(socket, state, claims.sub, claims.display_name, claims.email))
|
||||
ws.on_upgrade(move |socket| {
|
||||
handle_socket(socket, state, claims.sub, claims.display_name, claims.email)
|
||||
})
|
||||
}
|
||||
|
||||
async fn handle_socket(socket: WebSocket, state: Arc<AppState>, user_id: String, display_name: String, email: String) {
|
||||
/// Drive a single WebSocket connection until either the send or receive side ends.
|
||||
async fn handle_socket(
|
||||
socket: WebSocket,
|
||||
state: Arc<AppState>,
|
||||
user_id: String,
|
||||
display_name: String,
|
||||
email: String,
|
||||
) {
|
||||
let (mut ws_tx, mut ws_rx) = socket.split();
|
||||
let mut broadcast_rx = state.tx.subscribe();
|
||||
|
||||
@ -50,7 +66,8 @@ async fn handle_socket(socket: WebSocket, state: Arc<AppState>, user_id: String,
|
||||
|
||||
let rooms_clone = subscribed_rooms.clone();
|
||||
|
||||
// Task: forward broadcast events to this client
|
||||
// Task 1: forward room events from the shared broadcast channel into this
|
||||
// specific socket, but only for rooms the browser subscribed to.
|
||||
let mut send_task = tokio::spawn(async move {
|
||||
loop {
|
||||
match broadcast_rx.recv().await {
|
||||
@ -81,7 +98,8 @@ async fn handle_socket(socket: WebSocket, state: Arc<AppState>, user_id: String,
|
||||
let email_clone = email.clone();
|
||||
let rooms_clone2 = subscribed_rooms.clone();
|
||||
|
||||
// Task: receive messages from client
|
||||
// Task 2: receive commands from the browser and translate them into
|
||||
// database writes, broadcasts, or AI work.
|
||||
let mut recv_task = tokio::spawn(async move {
|
||||
while let Some(Ok(msg)) = ws_rx.next().await {
|
||||
let text = match msg {
|
||||
@ -141,7 +159,7 @@ async fn handle_socket(socket: WebSocket, state: Arc<AppState>, user_id: String,
|
||||
}
|
||||
});
|
||||
|
||||
// Wait for either task to finish, then abort the other
|
||||
// If either half of the connection ends, stop the companion task too.
|
||||
tokio::select! {
|
||||
_ = &mut send_task => recv_task.abort(),
|
||||
_ = &mut recv_task => send_task.abort(),
|
||||
@ -150,6 +168,7 @@ async fn handle_socket(socket: WebSocket, state: Arc<AppState>, user_id: String,
|
||||
tracing::info!("WebSocket disconnected: {}", user_id);
|
||||
}
|
||||
|
||||
/// Persist a user message, broadcast it, and optionally generate an AI reply.
|
||||
async fn handle_send_message(
|
||||
state: &Arc<AppState>,
|
||||
user_id: &str,
|
||||
@ -184,13 +203,14 @@ async fn handle_send_message(
|
||||
.await;
|
||||
|
||||
// Look up the sender's custom avatar (if any) for the message payload
|
||||
let avatar_url: Option<String> = sqlx::query_scalar("SELECT avatar_url FROM users WHERE id = ?")
|
||||
.bind(user_id)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.ok()
|
||||
.flatten()
|
||||
.flatten();
|
||||
let avatar_url: Option<String> =
|
||||
sqlx::query_scalar("SELECT avatar_url FROM users WHERE id = ?")
|
||||
.bind(user_id)
|
||||
.fetch_optional(&state.db)
|
||||
.await
|
||||
.ok()
|
||||
.flatten()
|
||||
.flatten();
|
||||
|
||||
// Broadcast human message
|
||||
let payload = MessagePayload {
|
||||
@ -211,12 +231,11 @@ async fn handle_send_message(
|
||||
|
||||
let _ = state.tx.send(BroadcastEvent {
|
||||
room_id: room_id.to_string(),
|
||||
message: WsServerMessage::NewMessage {
|
||||
message: payload,
|
||||
},
|
||||
message: WsServerMessage::NewMessage { message: payload },
|
||||
});
|
||||
|
||||
// Check if AI should respond
|
||||
// The AI only replies when explicitly mentioned or when the room is set to
|
||||
// auto-reply to every message.
|
||||
let ai_user_id = "ai-assistant";
|
||||
let should_respond = mentions.contains(&ai_user_id.to_string());
|
||||
|
||||
@ -254,7 +273,8 @@ async fn handle_send_message(
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
|
||||
// Process history: encode images as base64 data URLs for OpenRouter
|
||||
// OpenRouter accepts image inputs as data URLs, so local uploads need to be
|
||||
// loaded from disk and encoded before they are sent upstream.
|
||||
let mut history: Vec<(String, String, bool, Option<String>)> = Vec::new();
|
||||
for (sender_name, msg_content, is_ai, msg_image_url) in recent_messages.into_iter().rev() {
|
||||
let image_data_url = match &msg_image_url {
|
||||
@ -272,7 +292,8 @@ async fn handle_send_message(
|
||||
// Pre-generate AI message ID so we can reference it in stream chunks
|
||||
let ai_msg_id = Uuid::new_v4().to_string();
|
||||
|
||||
// Call OpenRouter with tool loop — uses streaming for all rounds
|
||||
// Run the AI in a loop because the model may first request tools, then need
|
||||
// follow-up rounds after those tool results are added to history.
|
||||
let mut total_prompt_tokens: u32 = 0;
|
||||
let mut total_completion_tokens: u32 = 0;
|
||||
let mut total_response_ms: u64 = 0;
|
||||
@ -313,16 +334,24 @@ async fn handle_send_message(
|
||||
tracing::info!(
|
||||
"AI requesting tool calls (round {}): {:?}",
|
||||
round + 1,
|
||||
assistant_msg.tool_calls.as_ref().map(|tc| tc.iter().map(|t| &t.function.name).collect::<Vec<_>>())
|
||||
assistant_msg
|
||||
.tool_calls
|
||||
.as_ref()
|
||||
.map(|tc| tc.iter().map(|t| &t.function.name).collect::<Vec<_>>())
|
||||
);
|
||||
|
||||
// Add the assistant's tool-call message to history
|
||||
// Preserve the assistant tool-call message so the next round
|
||||
// has the same context the model produced.
|
||||
let tool_calls = assistant_msg.tool_calls.clone().unwrap_or_default();
|
||||
chat_history.push(assistant_msg);
|
||||
|
||||
// Execute each tool call and add results
|
||||
// Tool results are fed back into the conversation as
|
||||
// synthetic `tool` messages, matching the upstream API.
|
||||
for tool_call in &tool_calls {
|
||||
let tool_input = extract_tool_input(&tool_call.function.name, &tool_call.function.arguments);
|
||||
let tool_input = extract_tool_input(
|
||||
&tool_call.function.name,
|
||||
&tool_call.function.arguments,
|
||||
);
|
||||
|
||||
// Broadcast real-time tool usage event
|
||||
let _ = state.tx.send(BroadcastEvent {
|
||||
@ -362,7 +391,7 @@ async fn handle_send_message(
|
||||
tool_call_id: Some(tool_call.id.clone()),
|
||||
});
|
||||
}
|
||||
// Continue to next round (tool loop)
|
||||
// Ask the model to continue now that tool output exists.
|
||||
continue 'tool_loop;
|
||||
}
|
||||
openrouter::StreamEvent::Done(stats) => {
|
||||
@ -382,9 +411,12 @@ async fn handle_send_message(
|
||||
}
|
||||
}
|
||||
|
||||
// If we exhausted all rounds without a text response, note it
|
||||
// Guardrail: if the model never produced final prose, store a clear fallback
|
||||
// instead of leaving the client waiting indefinitely.
|
||||
if ai_response.is_empty() && !had_error {
|
||||
ai_response = "*I used several tools but couldn't formulate a final response. Please try again.*".to_string();
|
||||
ai_response =
|
||||
"*I used several tools but couldn't formulate a final response. Please try again.*"
|
||||
.to_string();
|
||||
}
|
||||
|
||||
// Signal stream end so client can finalize rendering
|
||||
@ -512,7 +544,15 @@ async fn execute_tool(
|
||||
return "Error: search query is required".into();
|
||||
}
|
||||
|
||||
match search::search(search_provider, &query, tavily_api_key, brave_api_key, count).await {
|
||||
match search::search(
|
||||
search_provider,
|
||||
&query,
|
||||
tavily_api_key,
|
||||
brave_api_key,
|
||||
count,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(results) => search::format_results(&results),
|
||||
Err(e) => format!("Search error: {}", e),
|
||||
}
|
||||
|
||||
@ -1,3 +1,12 @@
|
||||
//! Application bootstrap for the GroupChat server.
|
||||
//!
|
||||
//! This file is responsible for:
|
||||
//! - loading environment configuration
|
||||
//! - opening and migrating the SQLite database
|
||||
//! - constructing shared application state
|
||||
//! - registering HTTP/WebSocket routes
|
||||
//! - serving the SPA frontend in production
|
||||
|
||||
mod handlers;
|
||||
mod middleware;
|
||||
mod models;
|
||||
@ -51,14 +60,8 @@ fn backup_database(database_url: &str) {
|
||||
}
|
||||
|
||||
// Build timestamped backup filename: chat.db -> chat_2026-03-09_143022.db
|
||||
let stem = db_file
|
||||
.file_stem()
|
||||
.and_then(|s| s.to_str())
|
||||
.unwrap_or("db");
|
||||
let ext = db_file
|
||||
.extension()
|
||||
.and_then(|s| s.to_str())
|
||||
.unwrap_or("db");
|
||||
let stem = db_file.file_stem().and_then(|s| s.to_str()).unwrap_or("db");
|
||||
let ext = db_file.extension().and_then(|s| s.to_str()).unwrap_or("db");
|
||||
|
||||
let now = chrono::Local::now();
|
||||
let backup_name = format!("{}_{}.{}", stem, now.format("%Y-%m-%d_%H%M%S"), ext);
|
||||
@ -82,11 +85,21 @@ fn backup_database(database_url: &str) {
|
||||
let wal_path = format!("{}-wal", db_path);
|
||||
let shm_path = format!("{}-shm", db_path);
|
||||
if std::path::Path::new(&wal_path).exists() {
|
||||
let wal_backup = backup_dir.join(format!("{}_{}.{}-wal", stem, now.format("%Y-%m-%d_%H%M%S"), ext));
|
||||
let wal_backup = backup_dir.join(format!(
|
||||
"{}_{}.{}-wal",
|
||||
stem,
|
||||
now.format("%Y-%m-%d_%H%M%S"),
|
||||
ext
|
||||
));
|
||||
let _ = std::fs::copy(&wal_path, &wal_backup);
|
||||
}
|
||||
if std::path::Path::new(&shm_path).exists() {
|
||||
let shm_backup = backup_dir.join(format!("{}_{}.{}-shm", stem, now.format("%Y-%m-%d_%H%M%S"), ext));
|
||||
let shm_backup = backup_dir.join(format!(
|
||||
"{}_{}.{}-shm",
|
||||
stem,
|
||||
now.format("%Y-%m-%d_%H%M%S"),
|
||||
ext
|
||||
));
|
||||
let _ = std::fs::copy(&shm_path, &shm_backup);
|
||||
}
|
||||
|
||||
@ -119,20 +132,35 @@ fn prune_old_backups(backup_dir: &std::path::Path, stem: &str, keep: usize) {
|
||||
let to_remove = backups.len() - keep;
|
||||
for entry in backups.into_iter().take(to_remove) {
|
||||
let path = entry.path();
|
||||
let name = path.file_name().unwrap_or_default().to_string_lossy().to_string();
|
||||
let name = path
|
||||
.file_name()
|
||||
.unwrap_or_default()
|
||||
.to_string_lossy()
|
||||
.to_string();
|
||||
if let Err(e) = std::fs::remove_file(&path) {
|
||||
tracing::warn!("Failed to remove old backup {}: {}", name, e);
|
||||
} else {
|
||||
tracing::debug!("Pruned old backup: {}", name);
|
||||
// Also remove associated WAL/SHM backups
|
||||
let wal = path.with_extension(format!("{}-wal", path.extension().unwrap_or_default().to_string_lossy()));
|
||||
let shm = path.with_extension(format!("{}-shm", path.extension().unwrap_or_default().to_string_lossy()));
|
||||
let wal = path.with_extension(format!(
|
||||
"{}-wal",
|
||||
path.extension().unwrap_or_default().to_string_lossy()
|
||||
));
|
||||
let shm = path.with_extension(format!(
|
||||
"{}-shm",
|
||||
path.extension().unwrap_or_default().to_string_lossy()
|
||||
));
|
||||
let _ = std::fs::remove_file(&wal);
|
||||
let _ = std::fs::remove_file(&shm);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Shared state injected into every handler.
|
||||
///
|
||||
/// Axum stores this behind an `Arc`, so handlers can cheaply clone the pointer
|
||||
/// while all requests still talk to the same database pool, API keys, and
|
||||
/// broadcast channel.
|
||||
pub struct AppState {
|
||||
pub db: sqlx::SqlitePool,
|
||||
pub jwt_secret: String,
|
||||
@ -154,11 +182,15 @@ async fn main() {
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.init();
|
||||
|
||||
let database_url = std::env::var("DATABASE_URL").unwrap_or_else(|_| "sqlite:chat.db?mode=rwc".into());
|
||||
// Load the runtime configuration needed to start the server.
|
||||
let database_url =
|
||||
std::env::var("DATABASE_URL").unwrap_or_else(|_| "sqlite:chat.db?mode=rwc".into());
|
||||
let jwt_secret = std::env::var("JWT_SECRET").unwrap_or_else(|_| "dev-secret-change-me".into());
|
||||
let openrouter_key = std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY must be set");
|
||||
let search_provider = SearchProvider::from_env(std::env::var("SEARCH_PROVIDER").ok().as_deref())
|
||||
.unwrap_or_else(|e| panic!("{}", e));
|
||||
let openrouter_key =
|
||||
std::env::var("OPENROUTER_API_KEY").expect("OPENROUTER_API_KEY must be set");
|
||||
let search_provider =
|
||||
SearchProvider::from_env(std::env::var("SEARCH_PROVIDER").ok().as_deref())
|
||||
.unwrap_or_else(|e| panic!("{}", e));
|
||||
let tavily_api_key = std::env::var("TAVILY_API_KEY").ok();
|
||||
let brave_api_key = std::env::var("BRAVE_API_KEY").ok();
|
||||
|
||||
@ -181,7 +213,8 @@ async fn main() {
|
||||
.await
|
||||
.expect("Failed to connect to database");
|
||||
|
||||
// Run migrations
|
||||
// Run migrations in order. Each one is written so startup can safely try it
|
||||
// again and skip work that already happened in an earlier run.
|
||||
let migration_sql = include_str!("../migrations/001_init.sql");
|
||||
sqlx::raw_sql(migration_sql)
|
||||
.execute(&db)
|
||||
@ -282,6 +315,8 @@ async fn main() {
|
||||
|
||||
tracing::info!("Database initialized");
|
||||
|
||||
// WebSocket tasks subscribe to this channel to receive room events without
|
||||
// polling the database.
|
||||
let (tx, _rx) = broadcast::channel::<models::BroadcastEvent>(4096);
|
||||
|
||||
let state = Arc::new(AppState {
|
||||
@ -302,32 +337,61 @@ async fn main() {
|
||||
// Serve static files from client dist in production
|
||||
let static_dir = std::env::var("STATIC_DIR").unwrap_or_else(|_| "../client/dist".into());
|
||||
|
||||
// Keep API routes separate from the static-file fallback so `/api/*` and
|
||||
// `/ws` requests never get mistaken for SPA routes.
|
||||
let api_routes = Router::new()
|
||||
// Auth routes
|
||||
.route("/api/auth/register", post(handlers::auth::register))
|
||||
.route("/api/auth/login", post(handlers::auth::login))
|
||||
.route("/api/auth/me", get(handlers::auth::me))
|
||||
// Nostr auth routes
|
||||
.route("/api/auth/nostr/challenge", get(handlers::nostr_auth::challenge))
|
||||
.route(
|
||||
"/api/auth/nostr/challenge",
|
||||
get(handlers::nostr_auth::challenge),
|
||||
)
|
||||
.route("/api/auth/nostr/verify", post(handlers::nostr_auth::verify))
|
||||
// Profile routes
|
||||
.route("/api/auth/profile", put(handlers::profile::update_profile))
|
||||
.route("/api/auth/avatar", post(handlers::profile::upload_avatar).delete(handlers::profile::delete_avatar))
|
||||
.route(
|
||||
"/api/auth/avatar",
|
||||
post(handlers::profile::upload_avatar).delete(handlers::profile::delete_avatar),
|
||||
)
|
||||
// Room routes
|
||||
.route("/api/rooms", get(handlers::rooms::list_rooms).post(handlers::rooms::create_room))
|
||||
.route("/api/rooms/:room_id", get(handlers::rooms::get_room).delete(handlers::rooms::delete_room))
|
||||
.route("/api/rooms/:room_id/messages", get(handlers::rooms::get_messages))
|
||||
.route(
|
||||
"/api/rooms",
|
||||
get(handlers::rooms::list_rooms).post(handlers::rooms::create_room),
|
||||
)
|
||||
.route(
|
||||
"/api/rooms/:room_id",
|
||||
get(handlers::rooms::get_room).delete(handlers::rooms::delete_room),
|
||||
)
|
||||
.route(
|
||||
"/api/rooms/:room_id/messages",
|
||||
get(handlers::rooms::get_messages),
|
||||
)
|
||||
.route("/api/rooms/:room_id/join", post(handlers::rooms::join_room))
|
||||
.route("/api/rooms/:room_id/clear", post(handlers::rooms::clear_room))
|
||||
.route("/api/messages/hash/:hash", get(handlers::rooms::resolve_message_hash))
|
||||
.route(
|
||||
"/api/rooms/:room_id/clear",
|
||||
post(handlers::rooms::clear_room),
|
||||
)
|
||||
.route(
|
||||
"/api/messages/hash/:hash",
|
||||
get(handlers::rooms::resolve_message_hash),
|
||||
)
|
||||
// Upload (chat images)
|
||||
.route("/api/upload", post(handlers::upload::upload_chat_image))
|
||||
// Models
|
||||
.route("/api/models", get(handlers::models::list_models))
|
||||
// Invite routes
|
||||
.route("/api/invites", post(handlers::invites::create_invite))
|
||||
.route("/api/invites/:token/accept", post(handlers::invites::accept_invite))
|
||||
.route("/api/invites/nostr", post(handlers::invites::invite_by_nostr))
|
||||
.route(
|
||||
"/api/invites/:token/accept",
|
||||
post(handlers::invites::accept_invite),
|
||||
)
|
||||
.route(
|
||||
"/api/invites/nostr",
|
||||
post(handlers::invites::invite_by_nostr),
|
||||
)
|
||||
// Uploaded files (avatars)
|
||||
.nest_service("/uploads", ServeDir::new("uploads"))
|
||||
// WebSocket
|
||||
|
||||
@ -1,14 +1,11 @@
|
||||
use async_trait::async_trait;
|
||||
use axum::{
|
||||
extract::FromRequestParts,
|
||||
http::request::Parts,
|
||||
};
|
||||
use axum::{extract::FromRequestParts, http::request::Parts};
|
||||
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{models::Claims, AppState};
|
||||
|
||||
/// Extract authenticated user from JWT in Authorization header
|
||||
/// Authenticated user information extracted from the bearer token.
|
||||
pub struct AuthUser {
|
||||
pub user_id: String,
|
||||
pub email: String,
|
||||
@ -19,7 +16,15 @@ pub struct AuthUser {
|
||||
impl FromRequestParts<Arc<AppState>> for AuthUser {
|
||||
type Rejection = axum::http::StatusCode;
|
||||
|
||||
async fn from_request_parts(parts: &mut Parts, state: &Arc<AppState>) -> Result<Self, Self::Rejection> {
|
||||
/// Read the `Authorization: Bearer <token>` header and decode the JWT.
|
||||
///
|
||||
/// Axum runs this automatically for any handler parameter of type
|
||||
/// `AuthUser`, which keeps individual handlers free from repeated token
|
||||
/// parsing logic.
|
||||
async fn from_request_parts(
|
||||
parts: &mut Parts,
|
||||
state: &Arc<AppState>,
|
||||
) -> Result<Self, Self::Rejection> {
|
||||
let auth_header = parts
|
||||
.headers
|
||||
.get("Authorization")
|
||||
@ -41,7 +46,16 @@ impl FromRequestParts<Arc<AppState>> for AuthUser {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_token(user_id: &str, email: &str, display_name: &str, secret: &str) -> Result<String, jsonwebtoken::errors::Error> {
|
||||
/// Create a signed JWT for a logged-in user.
|
||||
///
|
||||
/// The token expires after seven days and carries the small amount of identity
|
||||
/// data the server wants available on every request.
|
||||
pub fn create_token(
|
||||
user_id: &str,
|
||||
email: &str,
|
||||
display_name: &str,
|
||||
secret: &str,
|
||||
) -> Result<String, jsonwebtoken::errors::Error> {
|
||||
let expiration = chrono::Utc::now()
|
||||
.checked_add_signed(chrono::Duration::days(7))
|
||||
.unwrap()
|
||||
@ -61,6 +75,7 @@ pub fn create_token(user_id: &str, email: &str, display_name: &str, secret: &str
|
||||
)
|
||||
}
|
||||
|
||||
/// Decode and validate a previously issued JWT.
|
||||
pub fn decode_token(token: &str, secret: &str) -> Result<Claims, jsonwebtoken::errors::Error> {
|
||||
let token_data = decode::<Claims>(
|
||||
token,
|
||||
|
||||
@ -1 +1,3 @@
|
||||
//! Reusable request-processing layers shared across handlers.
|
||||
|
||||
pub mod auth;
|
||||
|
||||
@ -1,7 +1,14 @@
|
||||
//! Core data structures shared across the server.
|
||||
//!
|
||||
//! This file intentionally mixes database row types, HTTP payloads, WebSocket
|
||||
//! payloads, and a few helper functions so the rest of the codebase can import
|
||||
//! common shapes from one place.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
// ── Database models ──
|
||||
|
||||
/// Row from the `users` table.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
|
||||
pub struct User {
|
||||
pub id: String,
|
||||
@ -11,6 +18,7 @@ pub struct User {
|
||||
pub created_at: String,
|
||||
}
|
||||
|
||||
/// Row from the `rooms` table.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
|
||||
pub struct Room {
|
||||
pub id: String,
|
||||
@ -24,6 +32,7 @@ pub struct Room {
|
||||
pub deleted_at: Option<String>,
|
||||
}
|
||||
|
||||
/// Row from the `messages` table.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
|
||||
pub struct Message {
|
||||
pub id: String,
|
||||
@ -38,6 +47,7 @@ pub struct Message {
|
||||
pub hash: Option<String>,
|
||||
}
|
||||
|
||||
/// Row from the `invites` table.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
|
||||
pub struct Invite {
|
||||
pub id: String,
|
||||
@ -51,6 +61,7 @@ pub struct Invite {
|
||||
|
||||
// ── API request/response types ──
|
||||
|
||||
/// JSON body expected by the registration endpoint.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct RegisterRequest {
|
||||
pub email: String,
|
||||
@ -58,18 +69,21 @@ pub struct RegisterRequest {
|
||||
pub display_name: String,
|
||||
}
|
||||
|
||||
/// JSON body expected by the login endpoint.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct LoginRequest {
|
||||
pub email: String,
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
/// Standard auth response returned after login, registration, or profile update.
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct AuthResponse {
|
||||
pub token: String,
|
||||
pub user: UserPublic,
|
||||
}
|
||||
|
||||
/// Public user data safe to return to any authenticated client.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct UserPublic {
|
||||
pub id: String,
|
||||
@ -81,6 +95,7 @@ pub struct UserPublic {
|
||||
pub nostr_pubkey: Option<String>,
|
||||
}
|
||||
|
||||
/// JSON body used when a user creates a new chat room.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CreateRoomRequest {
|
||||
pub name: String,
|
||||
@ -93,10 +108,10 @@ pub struct CreateRoomRequest {
|
||||
pub ai_name: String,
|
||||
}
|
||||
|
||||
/// Pick a friendly default AI display name when the creator does not specify one.
|
||||
fn default_ai_name() -> String {
|
||||
let names = [
|
||||
"Nova", "Atlas", "Sage", "Echo", "Pixel",
|
||||
"Cosmo", "Ember", "Flux", "Lyra", "Onyx",
|
||||
"Nova", "Atlas", "Sage", "Echo", "Pixel", "Cosmo", "Ember", "Flux", "Lyra", "Onyx",
|
||||
];
|
||||
let idx = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
@ -105,10 +120,12 @@ fn default_ai_name() -> String {
|
||||
names[idx].to_string()
|
||||
}
|
||||
|
||||
/// Default prompt that defines the AI assistant's behavior inside a room.
|
||||
fn default_system_prompt() -> String {
|
||||
"You are a helpful AI assistant participating in a group chat. Be conversational, helpful, and concise. You can see messages from all participants. When mentioned with @ai, respond helpfully.\n\nYou have access to tools:\n- **web_search**: Search the web for current information. Use this when asked about recent events, news, facts you're unsure about, or anything that needs up-to-date information.\n- **web_fetch**: Fetch and read the content of a web page. Use this when a user shares a URL and wants you to read/summarize it, or when you need more details from a search result.\n\nUse tools proactively when they would help answer the question better. You don't need to ask permission to use them.".to_string()
|
||||
}
|
||||
|
||||
/// Full room payload returned to the client, including current members.
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct RoomResponse {
|
||||
pub id: String,
|
||||
@ -122,6 +139,7 @@ pub struct RoomResponse {
|
||||
pub members: Vec<UserPublic>,
|
||||
}
|
||||
|
||||
/// JSON body for an email-based room invite.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CreateInviteRequest {
|
||||
pub room_id: String,
|
||||
@ -130,6 +148,7 @@ pub struct CreateInviteRequest {
|
||||
|
||||
// ── WebSocket event types ──
|
||||
|
||||
/// Messages the browser can send over the WebSocket connection.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum WsClientMessage {
|
||||
@ -148,17 +167,14 @@ pub enum WsClientMessage {
|
||||
Typing { room_id: String },
|
||||
}
|
||||
|
||||
/// Messages the server can push to browsers over the WebSocket connection.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum WsServerMessage {
|
||||
#[serde(rename = "new_message")]
|
||||
NewMessage {
|
||||
message: MessagePayload,
|
||||
},
|
||||
NewMessage { message: MessagePayload },
|
||||
#[serde(rename = "ai_typing")]
|
||||
AiTyping {
|
||||
room_id: String,
|
||||
},
|
||||
AiTyping { room_id: String },
|
||||
#[serde(rename = "user_typing")]
|
||||
UserTyping {
|
||||
room_id: String,
|
||||
@ -166,21 +182,13 @@ pub enum WsServerMessage {
|
||||
display_name: String,
|
||||
},
|
||||
#[serde(rename = "error")]
|
||||
Error {
|
||||
message: String,
|
||||
},
|
||||
Error { message: String },
|
||||
#[serde(rename = "joined")]
|
||||
Joined {
|
||||
room_id: String,
|
||||
},
|
||||
Joined { room_id: String },
|
||||
#[serde(rename = "room_deleted")]
|
||||
RoomDeleted {
|
||||
room_id: String,
|
||||
},
|
||||
RoomDeleted { room_id: String },
|
||||
#[serde(rename = "room_cleared")]
|
||||
RoomCleared {
|
||||
room_id: String,
|
||||
},
|
||||
RoomCleared { room_id: String },
|
||||
#[serde(rename = "ai_tool_usage")]
|
||||
AiToolUsage {
|
||||
room_id: String,
|
||||
@ -194,12 +202,10 @@ pub enum WsServerMessage {
|
||||
delta: String,
|
||||
},
|
||||
#[serde(rename = "ai_stream_end")]
|
||||
AiStreamEnd {
|
||||
room_id: String,
|
||||
message_id: String,
|
||||
},
|
||||
AiStreamEnd { room_id: String, message_id: String },
|
||||
}
|
||||
|
||||
/// Message shape sent to clients for history loading and live updates.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MessagePayload {
|
||||
pub id: String,
|
||||
@ -224,7 +230,7 @@ pub struct MessagePayload {
|
||||
|
||||
/// Compute Gravatar-compatible MD5 hash from an email address.
|
||||
pub fn gravatar_hash(email: &str) -> String {
|
||||
use md5::{Md5, Digest};
|
||||
use md5::{Digest, Md5};
|
||||
let normalized = email.trim().to_lowercase();
|
||||
let result = Md5::digest(normalized.as_bytes());
|
||||
format!("{:x}", result)
|
||||
@ -232,13 +238,14 @@ pub fn gravatar_hash(email: &str) -> String {
|
||||
|
||||
/// Compute SHA-256 integrity hash from created_at timestamp + message content.
|
||||
pub fn message_hash(created_at: &str, content: &str) -> String {
|
||||
use sha2::{Sha256, Digest};
|
||||
use sha2::{Digest, Sha256};
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(created_at.as_bytes());
|
||||
hasher.update(content.as_bytes());
|
||||
format!("{:x}", hasher.finalize())
|
||||
}
|
||||
|
||||
/// Usage and tool metadata captured for AI-generated messages.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AiMeta {
|
||||
pub model: String,
|
||||
@ -250,6 +257,7 @@ pub struct AiMeta {
|
||||
pub tool_results: Option<Vec<ToolResult>>,
|
||||
}
|
||||
|
||||
/// One tool invocation performed while generating an AI answer.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolResult {
|
||||
pub tool: String,
|
||||
@ -259,6 +267,7 @@ pub struct ToolResult {
|
||||
|
||||
// ── Broadcast event (internal channel) ──
|
||||
|
||||
/// Internal event sent through a Tokio broadcast channel to WebSocket tasks.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BroadcastEvent {
|
||||
pub room_id: String,
|
||||
@ -267,9 +276,10 @@ pub struct BroadcastEvent {
|
||||
|
||||
// ── JWT Claims ──
|
||||
|
||||
/// Claims stored inside the server-issued JWT.
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Claims {
|
||||
pub sub: String, // user_id
|
||||
pub sub: String, // user_id
|
||||
pub email: String,
|
||||
pub display_name: String,
|
||||
pub exp: usize,
|
||||
@ -277,6 +287,7 @@ pub struct Claims {
|
||||
|
||||
// ── Pagination ──
|
||||
|
||||
/// Common pagination parameters for message history endpoints.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct PaginationParams {
|
||||
#[serde(default = "default_limit")]
|
||||
@ -288,7 +299,7 @@ fn default_limit() -> i64 {
|
||||
50
|
||||
}
|
||||
|
||||
/// Returns "" if the email is a sentinel nostr: value, otherwise returns it as-is.
|
||||
/// Hide placeholder `nostr:*` emails from normal client responses.
|
||||
pub fn public_email(email: &str) -> String {
|
||||
if email.starts_with("nostr:") {
|
||||
String::new()
|
||||
@ -299,11 +310,13 @@ pub fn public_email(email: &str) -> String {
|
||||
|
||||
// ── Nostr auth types ──
|
||||
|
||||
/// Response returned by the Nostr challenge endpoint.
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct NostrChallengeResponse {
|
||||
pub challenge: String,
|
||||
}
|
||||
|
||||
/// JSON body sent by the client when proving Nostr ownership.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct NostrVerifyRequest {
|
||||
pub signed_event: String,
|
||||
@ -312,6 +325,7 @@ pub struct NostrVerifyRequest {
|
||||
pub profile_picture: Option<String>,
|
||||
}
|
||||
|
||||
/// JSON body for inviting an already-known Nostr user into a room.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct NostrInviteRequest {
|
||||
pub room_id: String,
|
||||
|
||||
@ -4,6 +4,7 @@ use crate::services::search::SearchResult;
|
||||
|
||||
const BRAVE_SEARCH_URL: &str = "https://api.search.brave.com/res/v1/web/search";
|
||||
|
||||
/// Partial Brave API response containing only the fields this app needs.
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct BraveResponse {
|
||||
web: Option<BraveWebResults>,
|
||||
@ -27,11 +28,7 @@ struct BraveResult {
|
||||
|
||||
/// Search the web using the Brave Search API.
|
||||
/// Returns a list of simplified search results.
|
||||
pub async fn search(
|
||||
query: &str,
|
||||
api_key: &str,
|
||||
count: u8,
|
||||
) -> Result<Vec<SearchResult>, String> {
|
||||
pub async fn search(query: &str, api_key: &str, count: u8) -> Result<Vec<SearchResult>, String> {
|
||||
let count = count.clamp(1, 10);
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
@ -19,9 +19,29 @@ const STRIP_TAGS: &[&str] = &[
|
||||
|
||||
/// Block-level tags that should produce newlines in text output.
|
||||
const BLOCK_TAGS: &[&str] = &[
|
||||
"p", "div", "h1", "h2", "h3", "h4", "h5", "h6", "li", "br", "tr",
|
||||
"blockquote", "pre", "section", "article", "main", "header",
|
||||
"dt", "dd", "figcaption", "table", "thead", "tbody",
|
||||
"p",
|
||||
"div",
|
||||
"h1",
|
||||
"h2",
|
||||
"h3",
|
||||
"h4",
|
||||
"h5",
|
||||
"h6",
|
||||
"li",
|
||||
"br",
|
||||
"tr",
|
||||
"blockquote",
|
||||
"pre",
|
||||
"section",
|
||||
"article",
|
||||
"main",
|
||||
"header",
|
||||
"dt",
|
||||
"dd",
|
||||
"figcaption",
|
||||
"table",
|
||||
"thead",
|
||||
"tbody",
|
||||
];
|
||||
|
||||
/// Fetch a URL and extract its text content.
|
||||
|
||||
@ -1,3 +1,9 @@
|
||||
//! Integrations with external systems used by the chat server.
|
||||
//!
|
||||
//! These modules wrap search providers, web page fetching, and the OpenRouter
|
||||
//! chat completion API so the rest of the application can call them with simple
|
||||
//! Rust types.
|
||||
|
||||
pub mod brave;
|
||||
pub mod fetch;
|
||||
pub mod openrouter;
|
||||
|
||||
@ -235,7 +235,9 @@ pub async fn chat_completion_stream(
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
let _ = tx.send(StreamEvent::Error(format!("Request failed: {}", e))).await;
|
||||
let _ = tx
|
||||
.send(StreamEvent::Error(format!("Request failed: {}", e)))
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
@ -243,7 +245,12 @@ pub async fn chat_completion_stream(
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let body = response.text().await.unwrap_or_default();
|
||||
let _ = tx.send(StreamEvent::Error(format!("OpenRouter error {}: {}", status, body))).await;
|
||||
let _ = tx
|
||||
.send(StreamEvent::Error(format!(
|
||||
"OpenRouter error {}: {}",
|
||||
status, body
|
||||
)))
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
|
||||
@ -264,7 +271,9 @@ pub async fn chat_completion_stream(
|
||||
let bytes = match chunk_result {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
let _ = tx.send(StreamEvent::Error(format!("Stream error: {}", e))).await;
|
||||
let _ = tx
|
||||
.send(StreamEvent::Error(format!("Stream error: {}", e)))
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
@ -338,7 +347,10 @@ pub async fn chat_completion_stream(
|
||||
tool_call_accum[idx].function.name.push_str(name);
|
||||
}
|
||||
if let Some(args) = &func.arguments {
|
||||
tool_call_accum[idx].function.arguments.push_str(args);
|
||||
tool_call_accum[idx]
|
||||
.function
|
||||
.arguments
|
||||
.push_str(args);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -373,7 +385,11 @@ pub async fn chat_completion_stream(
|
||||
// AI requested tool calls
|
||||
let assistant_msg = ChatMessage {
|
||||
role: "assistant".into(),
|
||||
content: if full_content.is_empty() { None } else { Some(Content::Text(full_content)) },
|
||||
content: if full_content.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(Content::Text(full_content))
|
||||
},
|
||||
tool_calls: Some(tool_call_accum),
|
||||
tool_call_id: None,
|
||||
};
|
||||
@ -420,7 +436,9 @@ pub fn build_chat_history(
|
||||
Content::Parts(vec![
|
||||
ContentPart::Text { text },
|
||||
ContentPart::ImageUrl {
|
||||
image_url: ImageUrlData { url: data_url.clone() },
|
||||
image_url: ImageUrlData {
|
||||
url: data_url.clone(),
|
||||
},
|
||||
},
|
||||
])
|
||||
} else {
|
||||
|
||||
@ -2,6 +2,7 @@ use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::{brave, tavily};
|
||||
|
||||
/// Which search backend the AI tool layer should call.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum SearchProvider {
|
||||
Tavily,
|
||||
@ -9,8 +10,14 @@ pub enum SearchProvider {
|
||||
}
|
||||
|
||||
impl SearchProvider {
|
||||
/// Parse the `SEARCH_PROVIDER` environment variable into a supported variant.
|
||||
pub fn from_env(value: Option<&str>) -> Result<Self, String> {
|
||||
match value.unwrap_or("tavily").trim().to_ascii_lowercase().as_str() {
|
||||
match value
|
||||
.unwrap_or("tavily")
|
||||
.trim()
|
||||
.to_ascii_lowercase()
|
||||
.as_str()
|
||||
{
|
||||
"tavily" => Ok(Self::Tavily),
|
||||
"brave" => Ok(Self::Brave),
|
||||
other => Err(format!(
|
||||
@ -20,6 +27,7 @@ impl SearchProvider {
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the environment variable name required by the selected provider.
|
||||
pub fn required_key_name(self) -> &'static str {
|
||||
match self {
|
||||
Self::Tavily => "TAVILY_API_KEY",
|
||||
@ -28,6 +36,7 @@ impl SearchProvider {
|
||||
}
|
||||
}
|
||||
|
||||
/// Normalized search result shape shared across providers.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SearchResult {
|
||||
pub title: String,
|
||||
@ -36,6 +45,7 @@ pub struct SearchResult {
|
||||
pub age: Option<String>,
|
||||
}
|
||||
|
||||
/// Dispatch a search request to whichever provider the server is configured to use.
|
||||
pub async fn search(
|
||||
provider: SearchProvider,
|
||||
query: &str,
|
||||
@ -59,6 +69,7 @@ pub async fn search(
|
||||
}
|
||||
}
|
||||
|
||||
/// Turn search results into plain text the AI model can read as tool output.
|
||||
pub fn format_results(results: &[SearchResult]) -> String {
|
||||
if results.is_empty() {
|
||||
return "No search results found.".to_string();
|
||||
|
||||
@ -23,11 +23,7 @@ struct TavilyResult {
|
||||
published_date: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn search(
|
||||
query: &str,
|
||||
api_key: &str,
|
||||
count: u8,
|
||||
) -> Result<Vec<SearchResult>, String> {
|
||||
pub async fn search(query: &str, api_key: &str, count: u8) -> Result<Vec<SearchResult>, String> {
|
||||
let max_results = count.clamp(1, 10);
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
@ -75,7 +71,10 @@ pub async fn search(
|
||||
if first.description.is_empty() {
|
||||
first.description = format!("AI summary: {}", answer);
|
||||
} else {
|
||||
first.description = format!("AI summary: {}\nSource excerpt: {}", answer, first.description);
|
||||
first.description = format!(
|
||||
"AI summary: {}\nSource excerpt: {}",
|
||||
answer, first.description
|
||||
);
|
||||
}
|
||||
} else {
|
||||
results.push(SearchResult {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user