diff --git a/server/src/ws/handler.rs b/server/src/ws/handler.rs index f4682d8..b5d5fd6 100644 --- a/server/src/ws/handler.rs +++ b/server/src/ws/handler.rs @@ -4,6 +4,7 @@ use std::time::{Duration, Instant}; use actix_ws::{Message, Session}; use futures::StreamExt; use log::{info, warn}; +use tokio::sync::mpsc; use crate::models::{ClientType, WsMessage}; @@ -43,9 +44,8 @@ pub async fn ws_index( .to_string(); // Perform the WebSocket upgrade. - let (response, mut session, mut msg_stream) = actix_ws::handle(&req, body)?; + let (response, session, msg_stream) = actix_ws::handle(&req, body)?; - // Log the connection. info!( "[ws] {} connected to session {} (ip={})", match client_type { @@ -56,7 +56,21 @@ pub async fn ws_index( ip ); - // Register agent in shared state. + // ── Per-connection setup based on client type ────────────────────────── + + // Create an mpsc channel for outgoing messages to this client. + // Viewers get a larger buffer since they receive frames (high throughput). + // Agents get a smaller buffer since they only receive HUD commands (low frequency). + let channel_capacity = match client_type { + ClientType::Viewer => 120, + ClientType::Agent => 64, + }; + let (tx, mut rx) = mpsc::channel::(channel_capacity); + + // Track IDs for cleanup. + let mut agent_id_for_cleanup: Option = None; + let mut viewer_id_for_cleanup: Option = None; + if client_type == ClientType::Agent { let agent_id = uuid::Uuid::new_v4().to_string(); let agent = crate::models::AgentConnection { @@ -68,63 +82,121 @@ pub async fn ws_index( audio_active: false, }; state.register_agent(agent); + state.register_agent_channel(&session_id, tx); + agent_id_for_cleanup = Some(agent_id); + } else { + let viewer_id = uuid::Uuid::new_v4().to_string(); + state.register_viewer(&session_id, &viewer_id, tx); + + // Send the latest buffered frame to the new viewer immediately, + // so they see something without waiting for the next frame from the agent. + if let Some(frame_data) = state.get_latest_frame(&session_id) { + let msg = WsMessage::FrameBroadcast { + data: frame_data, + content_type: "image/jpeg".into(), + }; + let json = serde_json::to_string(&msg).unwrap_or_default(); + // Best-effort immediate send on the channel (non-blocking). + let _ = tx.try_send(json); + } + + viewer_id_for_cleanup = Some(viewer_id); } - // Clone references for the spawned task. + // ── Spawn the connection task ────────────────────────────────────────── + let state_clone = state.clone(); let session_id_clone = session_id.clone(); let client_type_clone = client_type.clone(); + let timeout = Duration::from_secs(state.idle_timeout_secs); - // Spawn the reader task – receives messages FROM the client. actix_web::rt::spawn(async move { - let mut last_heartbeat = Instant::now(); - let timeout = Duration::from_secs(state_clone.idle_timeout_secs); + let mut last_activity = Instant::now(); + let mut timeout_sleep = tokio::time::sleep(timeout); - while let Some(Ok(msg)) = msg_stream.next().await { - match msg { - Message::Text(text) => { - last_heartbeat = Instant::now(); - handle_text_message( - &text, - &session_id_clone, - &client_type_clone, - &state_clone, - &mut session, - ) - .await; + loop { + tokio::select! { + // ── Incoming WebSocket message from the client ───────────── + ws_msg = msg_stream.next() => { + last_activity = Instant::now(); + timeout_sleep = tokio::time::sleep(timeout); + + match ws_msg { + Some(Ok(Message::Text(text))) => { + handle_text_message( + &text, + &session_id_clone, + &client_type_clone, + &state_clone, + &session, + ) + .await; + } + Some(Ok(Message::Ping(bytes))) => { + let _ = session.pong(&bytes).await; + } + Some(Ok(Message::Close(reason))) => { + info!( + "[ws] {:?} disconnected from session {} (close: {:?})", + client_type_clone, session_id_clone, reason + ); + break; + } + Some(Ok(Message::Binary(_))) => { + // We only use text (JSON) messages. Ignore binary. + } + Some(Err(e)) => { + warn!("[ws] read error for {:?} on session {}: {}", client_type_clone, session_id_clone, e); + break; + } + None => { + // Stream ended (client disconnected). + break; + } + } } - Message::Ping(bytes) => { - last_heartbeat = Instant::now(); - let _ = session.pong(&bytes).await; + + // ── Outgoing message from the broadcast/forward channel ──── + out_msg = rx.recv() => { + last_activity = Instant::now(); + timeout_sleep = tokio::time::sleep(timeout); + + match out_msg { + Some(text) => { + if session.text(&text).await.is_err() { + warn!("[ws] write failed for {:?} on session {}", client_type_clone, session_id_clone); + break; + } + } + None => { + // Channel closed (sender dropped during cleanup). + break; + } + } } - Message::Close(reason) => { - info!( - "[ws] {:?} disconnected from session {} (close: {:?})", - client_type_clone, session_id_clone, reason + + // ── Idle timeout ─────────────────────────────────────────── + _ = &mut timeout_sleep => { + warn!( + "[ws] {:?} timed out on session {} ({}s idle)", + client_type_clone, session_id_clone, timeout.as_secs() ); - cleanup(&state_clone, &session_id_clone, &client_type_clone); - return; + break; } - _ => {} - } - - // Check for idle timeout. - if last_heartbeat.elapsed() > timeout { - warn!( - "[ws] {:?} timed out on session {}", - client_type_clone, session_id_clone - ); - cleanup(&state_clone, &session_id_clone, &client_type_clone); - let _ = session.close(None).await; - return; } } - info!( - "[ws] {:?} stream ended for session {}", - client_type_clone, session_id_clone + // ── Cleanup ──────────────────────────────────────────────────────── + cleanup_connection( + &state_clone, + &session_id_clone, + &client_type_clone, + &agent_id_for_cleanup, + &viewer_id_for_cleanup, ); - cleanup(&state_clone, &session_id_clone, &client_type_clone); + + // Best-effort close the WebSocket. + let _ = session.close(None).await; }); Ok(response) @@ -138,16 +210,23 @@ async fn handle_text_message( session_id: &str, client_type: &ClientType, state: &Arc, - ws_session: &mut Session, + ws_session: &Session, ) { let msg: WsMessage = match serde_json::from_str(raw) { Ok(m) => m, Err(e) => { - warn!("[ws] invalid message: {} ({})", e, raw.chars().take(120).collect::()); + warn!( + "[ws] invalid message: {} ({})", + e, + raw.chars().take(120).collect::() + ); let _ = ws_session - .text(serde_json::to_string(&WsMessage::Error { - message: format!("invalid message: {}", e), - }).unwrap_or_default()) + .text( + serde_json::to_string(&WsMessage::Error { + message: format!("invalid message: {}", e), + }) + .unwrap_or_default(), + ) .await; return; } @@ -157,117 +236,133 @@ async fn handle_text_message( // ── From Agent ──────────────────────────────────────────────────── WsMessage::DisplayFrame { data, .. } if *client_type == ClientType::Agent => { state.push_frame(session_id, data.clone()); - broadcast_to_viewers(state, session_id, &WsMessage::FrameBroadcast { + + let broadcast = WsMessage::FrameBroadcast { data, content_type: "image/jpeg".into(), - }).await; + }; + let json = serde_json::to_string(&broadcast).unwrap_or_default(); + state.broadcast_to_viewers(session_id, &json).await; } WsMessage::AudioFrame { data, .. } if *client_type == ClientType::Agent => { - broadcast_to_viewers(state, session_id, &WsMessage::AudioBroadcast { + let broadcast = WsMessage::AudioBroadcast { data, content_type: "audio/opus".into(), - }).await; + }; + let json = serde_json::to_string(&broadcast).unwrap_or_default(); + state.broadcast_to_viewers(session_id, &json).await; } - WsMessage::AgentInfo { agent_id, resolution, .. } if *client_type == ClientType::Agent => { + WsMessage::AgentInfo { + agent_id, + resolution, + .. + } if *client_type == ClientType::Agent => { state.activate_session(session_id, resolution.as_deref()); if let Some(session) = state.get_session(session_id) { - broadcast_to_viewers(state, session_id, &WsMessage::SessionUpdate { + let update = WsMessage::SessionUpdate { session_id: session_id.to_string(), status: session.status, resolution: session.resolution, - }).await; + }; + let json = serde_json::to_string(&update).unwrap_or_default(); + state.broadcast_to_viewers(session_id, &json).await; } - info!("[ws] agent {} reported for session {}", agent_id, session_id); - let _ = ws_session.text(serde_json::to_string(&WsMessage::Ack { - message: "agent registered".into(), - }).unwrap_or_default()).await; + info!( + "[ws] agent {} reported for session {}", + agent_id, session_id + ); + let _ = ws_session + .text( + serde_json::to_string(&WsMessage::Ack { + message: "agent registered".into(), + }) + .unwrap_or_default(), + ) + .await; } WsMessage::Heartbeat if *client_type == ClientType::Agent => { - // Keepalive — nothing to do. + // Keepalive — nothing to do, the timeout is reset by receiving any message. } // ── From Viewer ─────────────────────────────────────────────────── - WsMessage::HudCommand { command, params, .. } if *client_type == ClientType::Viewer => { - forward_to_agent(state, session_id, &WsMessage::ForwardHudCommand { command, params }).await; + WsMessage::HudCommand { + command, params, .. + } if *client_type == ClientType::Viewer => { + let forward = WsMessage::ForwardHudCommand { command, params }; + let json = serde_json::to_string(&forward).unwrap_or_default(); + if !state.send_to_agent(session_id, &json).await { + let _ = ws_session + .text( + serde_json::to_string(&WsMessage::Error { + message: "no agent connected for this session".into(), + }) + .unwrap_or_default(), + ) + .await; + } } - WsMessage::Resize { width, height, .. } if *client_type == ClientType::Viewer => { - forward_to_agent(state, session_id, &WsMessage::ForwardResize { width, height }).await; + WsMessage::Resize { + width, height, .. + } if *client_type == ClientType::Viewer => { + let forward = WsMessage::ForwardResize { width, height }; + let json = serde_json::to_string(&forward).unwrap_or_default(); + if !state.send_to_agent(session_id, &json).await { + let _ = ws_session + .text( + serde_json::to_string(&WsMessage::Error { + message: "no agent connected for this session".into(), + }) + .unwrap_or_default(), + ) + .await; + } + } + + // ── Heartbeat from Viewer ───────────────────────────────────────── + WsMessage::Heartbeat if *client_type == ClientType::Viewer => { + let _ = ws_session + .text( + serde_json::to_string(&WsMessage::Ack { + message: "heartbeat".into(), + }) + .unwrap_or_default(), + ) + .await; } // ── Fallback ────────────────────────────────────────────────────── _ => { - warn!("[ws] unexpected message type from {:?} for session {}", client_type, session_id); + warn!( + "[ws] unexpected message type from {:?} for session {}", + client_type, session_id + ); } } } -/// Broadcast a message to all **viewer** WebSocket connections for a session. -/// -/// TODO: maintain a registry of per-viewer Session senders in AppState. -/// For now we log and rely on the frame buffer for new viewers. -async fn broadcast_to_viewers( - _state: &Arc, - session_id: &str, - msg: &WsMessage, -) { - log::debug!("[ws] broadcast to viewers of session {}: {:?}", session_id, msg.msg_type_debug()); -} - -/// Forward a message to the agent connected to a session. -/// -/// TODO: maintain per-agent mpsc channel in AppState. -async fn forward_to_agent( - _state: &Arc, - session_id: &str, - msg: &WsMessage, -) { - log::info!( - "[ws] forward to agent in session {}: {:?}", - session_id, - msg.msg_type_debug() - ); -} - /// Clean up when a client disconnects. -fn cleanup(state: &Arc, session_id: &str, client_type: &ClientType) { - if *client_type == ClientType::Agent { - let agent_id = state - .agents - .iter() - .find(|r| r.session_id == session_id) - .map(|r| r.agent_id.clone()); - if let Some(aid) = agent_id { - state.unregister_agent(&aid); - } - } -} - -/// Helper trait to get a human-readable tag for logging. -trait MsgTypeDebug { - fn msg_type_debug(&self) -> &'static str; -} - -impl MsgTypeDebug for WsMessage { - fn msg_type_debug(&self) -> &'static str { - match self { - WsMessage::DisplayFrame { .. } => "display_frame", - WsMessage::AudioFrame { .. } => "audio_frame", - WsMessage::AgentInfo { .. } => "agent_info", - WsMessage::Heartbeat => "heartbeat", - WsMessage::HudCommand { .. } => "hud_command", - WsMessage::Resize { .. } => "resize", - WsMessage::FrameBroadcast { .. } => "frame_broadcast", - WsMessage::AudioBroadcast { .. } => "audio_broadcast", - WsMessage::SessionUpdate { .. } => "session_update", - WsMessage::ForwardHudCommand { .. } => "forward_hud_command", - WsMessage::ForwardResize { .. } => "forward_resize", - WsMessage::StreamControl { .. } => "stream_control", - WsMessage::Error { .. } => "error", - WsMessage::Ack { .. } => "ack", +fn cleanup_connection( + state: &Arc, + session_id: &str, + client_type: &ClientType, + agent_id: &Option, + viewer_id: &Option, +) { + match client_type { + ClientType::Agent => { + if let Some(aid) = agent_id { + state.unregister_agent(aid); + } + state.unregister_agent_channel(session_id); + } + ClientType::Viewer => { + if let Some(vid) = viewer_id { + state.unregister_viewer(session_id, vid); + } } } }