use std::sync::Arc; use std::time::{Duration, Instant}; use actix_ws::{Message, Session}; use futures::StreamExt; use log::{info, warn}; use crate::models::{ClientType, WsMessage}; /// ACTIX-WEB HTTP HANDLER /// /// Upgrades the HTTP connection to a WebSocket and spawns an async task /// that reads frames from the client and dispatches them. /// /// The query parameter `client_type` must be `"viewer"` or `"agent"`. pub async fn ws_index( req: actix_web::HttpRequest, body: actix_web::web::Payload, state: actix_web::web::Data>, path: actix_web::web::Path, ) -> Result { let session_id = path.into_inner(); // Validate that the session exists. if !state.sessions.contains_key(&session_id) { return Ok(actix_web::HttpResponse::NotFound().json( crate::models::ApiResponse::<()>::err("session not found"), )); } // Determine client type from query string. let query_str = req.query_string(); let client_type = if query_str.contains("client_type=agent") { ClientType::Agent } else { ClientType::Viewer }; let ip = req .connection_info() .realip_remote_addr() .unwrap_or("unknown") .to_string(); // Perform the WebSocket upgrade. let (response, mut session, mut msg_stream) = actix_ws::handle(&req, body)?; // Log the connection. info!( "[ws] {} connected to session {} (ip={})", match client_type { ClientType::Agent => "AGENT", ClientType::Viewer => "VIEWER", }, session_id, ip ); // Register agent in shared state. if client_type == ClientType::Agent { let agent_id = uuid::Uuid::new_v4().to_string(); let agent = crate::models::AgentConnection { agent_id: agent_id.clone(), session_id: session_id.clone(), connected_at: chrono::Utc::now(), ip_address: ip.clone(), display_active: false, audio_active: false, }; state.register_agent(agent); } // Clone references for the spawned task. let state_clone = state.clone(); let session_id_clone = session_id.clone(); let client_type_clone = client_type.clone(); // Spawn the reader task – receives messages FROM the client. actix_web::rt::spawn(async move { let mut last_heartbeat = Instant::now(); let timeout = Duration::from_secs(state_clone.idle_timeout_secs); while let Some(Ok(msg)) = msg_stream.next().await { match msg { Message::Text(text) => { last_heartbeat = Instant::now(); handle_text_message( &text, &session_id_clone, &client_type_clone, &state_clone, &mut session, ) .await; } Message::Ping(bytes) => { last_heartbeat = Instant::now(); let _ = session.pong(&bytes).await; } Message::Close(reason) => { info!( "[ws] {:?} disconnected from session {} (close: {:?})", client_type_clone, session_id_clone, reason ); cleanup(&state_clone, &session_id_clone, &client_type_clone); return; } _ => {} } // Check for idle timeout. if last_heartbeat.elapsed() > timeout { warn!( "[ws] {:?} timed out on session {}", client_type_clone, session_id_clone ); cleanup(&state_clone, &session_id_clone, &client_type_clone); let _ = session.close(None).await; return; } } info!( "[ws] {:?} stream ended for session {}", client_type_clone, session_id_clone ); cleanup(&state_clone, &session_id_clone, &client_type_clone); }); Ok(response) } // ── Internal helpers ───────────────────────────────────────────────────────── /// Parse and dispatch an incoming text (JSON) WebSocket message. async fn handle_text_message( raw: &str, session_id: &str, client_type: &ClientType, state: &Arc, ws_session: &mut Session, ) { let msg: WsMessage = match serde_json::from_str(raw) { Ok(m) => m, Err(e) => { warn!("[ws] invalid message: {} ({})", e, raw.chars().take(120).collect::()); let _ = ws_session .text(serde_json::to_string(&WsMessage::Error { message: format!("invalid message: {}", e), }).unwrap_or_default()) .await; return; } }; match msg { // ── From Agent ──────────────────────────────────────────────────── WsMessage::DisplayFrame { data, .. } if *client_type == ClientType::Agent => { state.push_frame(session_id, data.clone()); broadcast_to_viewers(state, session_id, &WsMessage::FrameBroadcast { data, content_type: "image/jpeg".into(), }).await; } WsMessage::AudioFrame { data, .. } if *client_type == ClientType::Agent => { broadcast_to_viewers(state, session_id, &WsMessage::AudioBroadcast { data, content_type: "audio/opus".into(), }).await; } WsMessage::AgentInfo { agent_id, resolution, .. } if *client_type == ClientType::Agent => { state.activate_session(session_id, resolution.as_deref()); if let Some(session) = state.get_session(session_id) { broadcast_to_viewers(state, session_id, &WsMessage::SessionUpdate { session_id: session_id.to_string(), status: session.status, resolution: session.resolution, }).await; } info!("[ws] agent {} reported for session {}", agent_id, session_id); let _ = ws_session.text(serde_json::to_string(&WsMessage::Ack { message: "agent registered".into(), }).unwrap_or_default()).await; } WsMessage::Heartbeat if *client_type == ClientType::Agent => { // Keepalive — nothing to do. } // ── From Viewer ─────────────────────────────────────────────────── WsMessage::HudCommand { command, params, .. } if *client_type == ClientType::Viewer => { forward_to_agent(state, session_id, &WsMessage::ForwardHudCommand { command, params }).await; } WsMessage::Resize { width, height, .. } if *client_type == ClientType::Viewer => { forward_to_agent(state, session_id, &WsMessage::ForwardResize { width, height }).await; } // ── Fallback ────────────────────────────────────────────────────── _ => { warn!("[ws] unexpected message type from {:?} for session {}", client_type, session_id); } } } /// Broadcast a message to all **viewer** WebSocket connections for a session. /// /// TODO: maintain a registry of per-viewer Session senders in AppState. /// For now we log and rely on the frame buffer for new viewers. async fn broadcast_to_viewers( _state: &Arc, session_id: &str, msg: &WsMessage, ) { log::debug!("[ws] broadcast to viewers of session {}: {:?}", session_id, msg.msg_type_debug()); } /// Forward a message to the agent connected to a session. /// /// TODO: maintain per-agent mpsc channel in AppState. async fn forward_to_agent( _state: &Arc, session_id: &str, msg: &WsMessage, ) { log::info!( "[ws] forward to agent in session {}: {:?}", session_id, msg.msg_type_debug() ); } /// Clean up when a client disconnects. fn cleanup(state: &Arc, session_id: &str, client_type: &ClientType) { if *client_type == ClientType::Agent { let agent_id = state .agents .iter() .find(|r| r.session_id == session_id) .map(|r| r.agent_id.clone()); if let Some(aid) = agent_id { state.unregister_agent(&aid); } } } /// Helper trait to get a human-readable tag for logging. trait MsgTypeDebug { fn msg_type_debug(&self) -> &'static str; } impl MsgTypeDebug for WsMessage { fn msg_type_debug(&self) -> &'static str { match self { WsMessage::DisplayFrame { .. } => "display_frame", WsMessage::AudioFrame { .. } => "audio_frame", WsMessage::AgentInfo { .. } => "agent_info", WsMessage::Heartbeat => "heartbeat", WsMessage::HudCommand { .. } => "hud_command", WsMessage::Resize { .. } => "resize", WsMessage::FrameBroadcast { .. } => "frame_broadcast", WsMessage::AudioBroadcast { .. } => "audio_broadcast", WsMessage::SessionUpdate { .. } => "session_update", WsMessage::ForwardHudCommand { .. } => "forward_hud_command", WsMessage::ForwardResize { .. } => "forward_resize", WsMessage::StreamControl { .. } => "stream_control", WsMessage::Error { .. } => "error", WsMessage::Ack { .. } => "ack", } } }