diff --git a/server/src/api/sessions.rs b/server/src/api/sessions.rs index 275daa2..65797c7 100644 --- a/server/src/api/sessions.rs +++ b/server/src/api/sessions.rs @@ -2,8 +2,8 @@ use actix_web::{web, HttpResponse}; use serde::Deserialize; use std::sync::Arc; -use crate::models::{ApiResponse, Session, WsMessage}; -use crate::state::AppState; +use crate::models::{ApiResponse, Session}; +use crate::state::{AppState, WsOutMessage}; /// `GET /api/sessions` — list every session. pub async fn list_sessions(state: web::Data>) -> HttpResponse { @@ -64,11 +64,11 @@ pub async fn send_hud_command( let session_id = path.into_inner(); // Build the forward message. - let msg = WsMessage::ForwardHudCommand { - command: body.command.clone(), - params: body.params.clone(), - }; - let json = match serde_json::to_string(&msg) { + let json = match serde_json::to_string(&serde_json::json!({ + "msg_type": "forward_hud_command", + "command": body.command, + "params": body.params, + })) { Ok(j) => j, Err(e) => { return HttpResponse::InternalServerError() @@ -77,7 +77,7 @@ pub async fn send_hud_command( }; // Send through the agent channel. - if state.send_to_agent(&session_id, &json).await { + if state.send_to_agent(&session_id, WsOutMessage::Text(json)).await { HttpResponse::Ok().json(ApiResponse::ok("command forwarded")) } else { HttpResponse::Conflict() diff --git a/server/src/ws/handler.rs b/server/src/ws/handler.rs index b5d5fd6..bd6a42d 100644 --- a/server/src/ws/handler.rs +++ b/server/src/ws/handler.rs @@ -6,14 +6,12 @@ use futures::StreamExt; use log::{info, warn}; use tokio::sync::mpsc; -use crate::models::{ClientType, WsMessage}; +use crate::state::{WsOutMessage, MAX_BINARY_FRAME_SIZE}; /// ACTIX-WEB HTTP HANDLER /// -/// Upgrades the HTTP connection to a WebSocket and spawns an async task -/// that reads frames from the client and dispatches them. -/// -/// The query parameter `client_type` must be `"viewer"` or `"agent"`. +/// Upgrades HTTP to WebSocket. Handles both binary frames (video relay) and +/// text frames (JSON control messages). pub async fn ws_index( req: actix_web::HttpRequest, body: actix_web::web::Payload, @@ -22,19 +20,17 @@ pub async fn ws_index( ) -> Result { let session_id = path.into_inner(); - // Validate that the session exists. if !state.sessions.contains_key(&session_id) { return Ok(actix_web::HttpResponse::NotFound().json( crate::models::ApiResponse::<()>::err("session not found"), )); } - // Determine client type from query string. let query_str = req.query_string(); let client_type = if query_str.contains("client_type=agent") { - ClientType::Agent + "agent" } else { - ClientType::Viewer + "viewer" }; let ip = req @@ -43,71 +39,49 @@ pub async fn ws_index( .unwrap_or("unknown") .to_string(); - // Perform the WebSocket upgrade. let (response, session, msg_stream) = actix_ws::handle(&req, body)?; - info!( - "[ws] {} connected to session {} (ip={})", - match client_type { - ClientType::Agent => "AGENT", - ClientType::Viewer => "VIEWER", - }, - session_id, - ip - ); + info!("[ws] {} connected to session {} (ip={})", client_type, session_id, ip); - // ── Per-connection setup based on client type ────────────────────────── - - // Create an mpsc channel for outgoing messages to this client. - // Viewers get a larger buffer since they receive frames (high throughput). - // Agents get a smaller buffer since they only receive HUD commands (low frequency). - let channel_capacity = match client_type { - ClientType::Viewer => 120, - ClientType::Agent => 64, + // Create mpsc channel for outgoing messages to this client. + let channel_cap = match client_type { + "agent" => 64, + _ => 120, }; - let (tx, mut rx) = mpsc::channel::(channel_capacity); + let (tx, mut rx) = mpsc::channel::(channel_cap); // Track IDs for cleanup. - let mut agent_id_for_cleanup: Option = None; - let mut viewer_id_for_cleanup: Option = None; + let mut agent_id: Option = None; + let mut viewer_id: Option = None; - if client_type == ClientType::Agent { - let agent_id = uuid::Uuid::new_v4().to_string(); + if client_type == "agent" { + let aid = uuid::Uuid::new_v4().to_string(); let agent = crate::models::AgentConnection { - agent_id: agent_id.clone(), + agent_id: aid.clone(), session_id: session_id.clone(), connected_at: chrono::Utc::now(), - ip_address: ip.clone(), + ip_address: ip, display_active: false, audio_active: false, }; state.register_agent(agent); state.register_agent_channel(&session_id, tx); - agent_id_for_cleanup = Some(agent_id); + agent_id = Some(aid); } else { - let viewer_id = uuid::Uuid::new_v4().to_string(); - state.register_viewer(&session_id, &viewer_id, tx); + let vid = uuid::Uuid::new_v4().to_string(); + state.register_viewer(&session_id, &vid, tx); - // Send the latest buffered frame to the new viewer immediately, - // so they see something without waiting for the next frame from the agent. + // Send the latest buffered frame to the new viewer immediately. if let Some(frame_data) = state.get_latest_frame(&session_id) { - let msg = WsMessage::FrameBroadcast { - data: frame_data, - content_type: "image/jpeg".into(), - }; - let json = serde_json::to_string(&msg).unwrap_or_default(); - // Best-effort immediate send on the channel (non-blocking). - let _ = tx.try_send(json); + let _ = tx.try_send(WsOutMessage::Binary(frame_data)); } - - viewer_id_for_cleanup = Some(viewer_id); + viewer_id = Some(vid); } - // ── Spawn the connection task ────────────────────────────────────────── - + // Spawn the connection task. let state_clone = state.clone(); let session_id_clone = session_id.clone(); - let client_type_clone = client_type.clone(); + let client_type_clone = client_type.to_string(); let timeout = Duration::from_secs(state.idle_timeout_secs); actix_web::rt::spawn(async move { @@ -116,253 +90,202 @@ pub async fn ws_index( loop { tokio::select! { - // ── Incoming WebSocket message from the client ───────────── + // ── Incoming WebSocket message ────────────────────────── ws_msg = msg_stream.next() => { last_activity = Instant::now(); timeout_sleep = tokio::time::sleep(timeout); match ws_msg { Some(Ok(Message::Text(text))) => { - handle_text_message( - &text, - &session_id_clone, - &client_type_clone, - &state_clone, - &session, - ) - .await; + handle_text_message(&text, &session_id_clone, &client_type_clone, &state_clone, &session).await; + } + Some(Ok(Message::Binary(data))) => { + handle_binary_message(&data, &session_id_clone, &client_type_clone, &state_clone).await; } Some(Ok(Message::Ping(bytes))) => { let _ = session.pong(&bytes).await; } Some(Ok(Message::Close(reason))) => { - info!( - "[ws] {:?} disconnected from session {} (close: {:?})", - client_type_clone, session_id_clone, reason - ); + info!("[ws] {} disconnected from session {} ({:?})", client_type_clone, session_id_clone, reason); break; } - Some(Ok(Message::Binary(_))) => { - // We only use text (JSON) messages. Ignore binary. - } + Some(Ok(Message::Frame(_))) => {} Some(Err(e)) => { - warn!("[ws] read error for {:?} on session {}: {}", client_type_clone, session_id_clone, e); - break; - } - None => { - // Stream ended (client disconnected). + warn!("[ws] read error {} on session {}: {}", client_type_clone, session_id_clone, e); break; } + None => { break; } } } - // ── Outgoing message from the broadcast/forward channel ──── + // ── Outgoing message from broadcast/forward channel ──── out_msg = rx.recv() => { last_activity = Instant::now(); timeout_sleep = tokio::time::sleep(timeout); match out_msg { - Some(text) => { - if session.text(&text).await.is_err() { - warn!("[ws] write failed for {:?} on session {}", client_type_clone, session_id_clone); + Some(WsOutMessage::Binary(data)) => { + if session.binary(data).await.is_err() { break; } } - None => { - // Channel closed (sender dropped during cleanup). - break; + Some(WsOutMessage::Text(text)) => { + if session.text(&text).await.is_err() { + break; + } } + None => { break; } } } - // ── Idle timeout ─────────────────────────────────────────── + // ── Idle timeout ─────────────────────────────────────── _ = &mut timeout_sleep => { - warn!( - "[ws] {:?} timed out on session {} ({}s idle)", - client_type_clone, session_id_clone, timeout.as_secs() - ); + warn!("[ws] {} timed out on session {}", client_type_clone, session_id_clone); break; } } } - // ── Cleanup ──────────────────────────────────────────────────────── - cleanup_connection( - &state_clone, - &session_id_clone, - &client_type_clone, - &agent_id_for_cleanup, - &viewer_id_for_cleanup, - ); - - // Best-effort close the WebSocket. + // Cleanup. + match client_type_clone.as_str() { + "agent" => { + if let Some(aid) = agent_id { + state_clone.unregister_agent(&aid); + } + state_clone.unregister_agent_channel(&session_id_clone); + } + _ => { + if let Some(vid) = viewer_id { + state_clone.unregister_viewer(&session_id_clone, &vid); + } + } + } let _ = session.close(None).await; }); Ok(response) } -// ── Internal helpers ───────────────────────────────────────────────────────── +// ── Binary frame handler (video relay) ──────────────────────────────────────── -/// Parse and dispatch an incoming text (JSON) WebSocket message. +/// Handle a binary WebSocket frame. +/// +/// From agent: Raw video frame → push to buffer + broadcast to all viewers. +/// From viewer: Ignored (viewers shouldn't send binary frames). +async fn handle_binary_message( + data: &[u8], + session_id: &str, + client_type: &str, + state: &Arc, +) { + if client_type != "agent" { + return; + } + + // Reject oversized frames. + if data.len() > MAX_BINARY_FRAME_SIZE { + warn!("[ws] oversized binary frame ({} bytes) from agent in session {}", data.len(), session_id); + return; + } + + // Store in ring buffer (for late-joining viewers). + state.push_frame(session_id, data.to_vec()); + + // Broadcast to all viewers (zero-copy-ish: data is cloned per viewer). + state.broadcast_binary_frame(session_id, data.to_vec()).await; +} + +// ── Text message handler (JSON control) ────────────────────────────────────── + +/// Handle a text (JSON) WebSocket message. async fn handle_text_message( raw: &str, session_id: &str, - client_type: &ClientType, + client_type: &str, state: &Arc, ws_session: &Session, ) { - let msg: WsMessage = match serde_json::from_str(raw) { - Ok(m) => m, + // Try to parse as JSON control message. + let v: serde_json::Value = match serde_json::from_str(raw) { + Ok(v) => v, Err(e) => { - warn!( - "[ws] invalid message: {} ({})", - e, - raw.chars().take(120).collect::() - ); - let _ = ws_session - .text( - serde_json::to_string(&WsMessage::Error { - message: format!("invalid message: {}", e), - }) - .unwrap_or_default(), - ) - .await; + warn!("[ws] invalid JSON from {}: {} ({})", client_type, e, &raw[..raw.len().min(100)]); + let _ = ws_session.text(serde_json::to_string(&serde_json::json!({ + "msg_type": "error", + "message": format!("invalid JSON: {}", e) + })).unwrap_or_default()).await; return; } }; - match msg { - // ── From Agent ──────────────────────────────────────────────────── - WsMessage::DisplayFrame { data, .. } if *client_type == ClientType::Agent => { - state.push_frame(session_id, data.clone()); + let msg_type = v["msg_type"].as_str().unwrap_or(""); - let broadcast = WsMessage::FrameBroadcast { - data, - content_type: "image/jpeg".into(), - }; - let json = serde_json::to_string(&broadcast).unwrap_or_default(); - state.broadcast_to_viewers(session_id, &json).await; + match (client_type, msg_type) { + // ── Agent messages ───────────────────────────────────────────── + ("agent", "agent_info") => { + let resolution = v["resolution"].as_str(); + state.activate_session(session_id, resolution); + + // Notify viewers. + let update = serde_json::json!({ + "msg_type": "session_update", + "session_id": session_id, + "status": "active", + "resolution": resolution, + }); + state.broadcast_text(session_id, serde_json::to_string(&update).unwrap_or_default()).await; + + let _ = ws_session.text(serde_json::to_string(&serde_json::json!({ + "msg_type": "ack", + "message": "agent registered" + })).unwrap_or_default()).await; + + info!("[ws] agent registered for session {} (encoder: {})", session_id, v["encoder"].as_str().unwrap_or("unknown")); + } + ("agent", "heartbeat") => { + // Keepalive — timer reset at outer level. } - WsMessage::AudioFrame { data, .. } if *client_type == ClientType::Agent => { - let broadcast = WsMessage::AudioBroadcast { - data, - content_type: "audio/opus".into(), - }; - let json = serde_json::to_string(&broadcast).unwrap_or_default(); - state.broadcast_to_viewers(session_id, &json).await; - } + // ── Viewer messages ──────────────────────────────────────────── + ("viewer", "hud_command") => { + let command = v["command"].as_str().unwrap_or(""); + let params = &v["params"]; - WsMessage::AgentInfo { - agent_id, - resolution, - .. - } if *client_type == ClientType::Agent => { - state.activate_session(session_id, resolution.as_deref()); - if let Some(session) = state.get_session(session_id) { - let update = WsMessage::SessionUpdate { - session_id: session_id.to_string(), - status: session.status, - resolution: session.resolution, - }; - let json = serde_json::to_string(&update).unwrap_or_default(); - state.broadcast_to_viewers(session_id, &json).await; - } - info!( - "[ws] agent {} reported for session {}", - agent_id, session_id - ); - let _ = ws_session - .text( - serde_json::to_string(&WsMessage::Ack { - message: "agent registered".into(), - }) - .unwrap_or_default(), - ) - .await; - } - - WsMessage::Heartbeat if *client_type == ClientType::Agent => { - // Keepalive — nothing to do, the timeout is reset by receiving any message. - } - - // ── From Viewer ─────────────────────────────────────────────────── - WsMessage::HudCommand { - command, params, .. - } if *client_type == ClientType::Viewer => { - let forward = WsMessage::ForwardHudCommand { command, params }; - let json = serde_json::to_string(&forward).unwrap_or_default(); - if !state.send_to_agent(session_id, &json).await { - let _ = ws_session - .text( - serde_json::to_string(&WsMessage::Error { - message: "no agent connected for this session".into(), - }) - .unwrap_or_default(), - ) - .await; + // Forward to agent as a text message. + let forward = serde_json::json!({ + "msg_type": "forward_hud_command", + "command": command, + "params": params, + }); + if !state.send_to_agent(session_id, WsOutMessage::Text( + serde_json::to_string(&forward).unwrap_or_default() + )).await { + let _ = ws_session.text(serde_json::to_string(&serde_json::json!({ + "msg_type": "error", + "message": "no agent connected" + })).unwrap_or_default()).await; } } - - WsMessage::Resize { - width, height, .. - } if *client_type == ClientType::Viewer => { - let forward = WsMessage::ForwardResize { width, height }; - let json = serde_json::to_string(&forward).unwrap_or_default(); - if !state.send_to_agent(session_id, &json).await { - let _ = ws_session - .text( - serde_json::to_string(&WsMessage::Error { - message: "no agent connected for this session".into(), - }) - .unwrap_or_default(), - ) - .await; - } + ("viewer", "resize") => { + let forward = serde_json::json!({ + "msg_type": "forward_resize", + "width": v["width"], + "height": v["height"], + }); + state.send_to_agent(session_id, WsOutMessage::Text( + serde_json::to_string(&forward).unwrap_or_default() + )).await; + } + ("viewer", "heartbeat") => { + let _ = ws_session.text(serde_json::to_string(&serde_json::json!({ + "msg_type": "ack", + "message": "heartbeat" + })).unwrap_or_default()).await; } - // ── Heartbeat from Viewer ───────────────────────────────────────── - WsMessage::Heartbeat if *client_type == ClientType::Viewer => { - let _ = ws_session - .text( - serde_json::to_string(&WsMessage::Ack { - message: "heartbeat".into(), - }) - .unwrap_or_default(), - ) - .await; - } - - // ── Fallback ────────────────────────────────────────────────────── + // ── Unknown ──────────────────────────────────────────────────── _ => { - warn!( - "[ws] unexpected message type from {:?} for session {}", - client_type, session_id - ); - } - } -} - -/// Clean up when a client disconnects. -fn cleanup_connection( - state: &Arc, - session_id: &str, - client_type: &ClientType, - agent_id: &Option, - viewer_id: &Option, -) { - match client_type { - ClientType::Agent => { - if let Some(aid) = agent_id { - state.unregister_agent(aid); - } - state.unregister_agent_channel(session_id); - } - ClientType::Viewer => { - if let Some(vid) = viewer_id { - state.unregister_viewer(session_id, vid); - } + warn!("[ws] unexpected msg_type '{}' from {} for session {}", msg_type, client_type, session_id); } } }