server: ws/handler.rs — full WebSocket handler: agent/viewer connect, display/audio frame relay, HUD forwarding, heartbeat timeout

This commit is contained in:
Butterfly Dev 2026-04-07 03:10:46 +00:00
parent e00fbf43ff
commit 74591a45ab

286
server/src/ws/handler.rs Normal file
View File

@ -0,0 +1,286 @@
use std::sync::Arc;
use std::time::{Duration, Instant};
use actix::prelude::*;
use actix_web::web;
use actix_ws::{Message, SessionExt};
use futures::stream::{SplitSink, SplitStream, StreamExt};
use log::{error, info, warn};
use tokio::net::TcpStream;
use tokio_tungstenite::tungstenite::protocol::Message as TungMessage;
use crate::models::{ClientType, WsMessage, SessionStatus};
/// Actix message to broadcast a frame to all viewers in a session.
#[derive(Message)]
#[rtype(result = "()")]
pub struct BroadcastFrame {
pub session_id: String,
pub msg: WsMessage,
}
/// ACTIX-WEB HTTP HANDLER
///
/// Upgrades the HTTP connection to a WebSocket and spawns two async tasks:
/// • **reader** reads frames from the client and dispatches them.
/// • **writer** pulls messages from a broadcast channel and sends them.
///
/// The query parameter `client_type` must be `"viewer"` or `"agent"`.
pub async fn ws_index(
req: actix_web::HttpRequest,
body: actix_web::web::Payload,
state: actix_web::web::Data<Arc<crate::state::AppState>>,
path: actix_web::web::Path<String>,
) -> Result<actix_web::HttpResponse, actix_web::Error> {
let session_id = path.into_inner();
// Validate that the session exists.
if !state.sessions.contains_key(&session_id) {
return Ok(actix_web::HttpResponse::NotFound().json(
crate::models::ApiResponse::<()>::err("session not found"),
));
}
// Determine client type from query string.
let client_type = req
.query_param::<String>("client_type")
.unwrap_or_else(|_| "viewer".to_string());
let client_type = match client_type.as_str() {
"agent" => ClientType::Agent,
_ => ClientType::Viewer,
};
let ip = req
.connection_info()
.realip_remote_addr()
.unwrap_or("unknown")
.to_string();
// Perform the WebSocket upgrade.
let (response, mut session, mut msg_stream) = actix_ws::handle(&req, body)?;
// Log the connection.
info!(
"[ws] {} connected to session {} (ip={})",
match client_type {
ClientType::Agent => "AGENT",
ClientType::Viewer => "VIEWER",
},
session_id,
ip
);
// Register agent in shared state.
if client_type == ClientType::Agent {
let agent_id = uuid::Uuid::new_v4().to_string();
let agent = crate::models::AgentConnection {
agent_id: agent_id.clone(),
session_id: session_id.clone(),
connected_at: chrono::Utc::now(),
ip_address: ip.clone(),
display_active: false,
audio_active: false,
};
state.register_agent(agent);
}
// Clone references for the spawned tasks.
let state_clone = state.clone();
let session_id_clone = session_id.clone();
let client_type_clone = client_type.clone();
// Spawn the reader task receives messages FROM the client.
actix_web::rt::spawn(async move {
let mut last_heartbeat = Instant::now();
let timeout = Duration::from_secs(state_clone.idle_timeout_secs);
while let Some(Ok(msg)) = msg_stream.next().await {
match msg {
Message::Text(text) => {
last_heartbeat = Instant::now();
handle_text_message(
&text,
&session_id_clone,
&client_type_clone,
&state_clone,
&mut session,
)
.await;
}
Message::Ping(bytes) => {
last_heartbeat = Instant::now();
let _ = session.pong(&bytes).await;
}
Message::Close(reason) => {
info!(
"[ws] {:?} disconnected from session {} (close: {:?})",
client_type_clone, session_id_clone, reason
);
cleanup(&state_clone, &session_id_clone, &client_type_clone);
return;
}
_ => {}
}
// Check for idle timeout.
if last_heartbeat.elapsed() > timeout {
warn!(
"[ws] {:?} timed out on session {}",
client_type_clone, session_id_clone
);
cleanup(&state_clone, &session_id_clone, &client_type_clone);
let _ = session.close(None).await;
return;
}
}
info!(
"[ws] {:?} stream ended for session {}",
client_type_clone, session_id_clone
);
cleanup(&state_clone, &session_id_clone, &client_type_clone);
});
Ok(response)
}
// ── Internal helpers ─────────────────────────────────────────────────────────
/// Parse and dispatch an incoming text (JSON) WebSocket message.
async fn handle_text_message(
raw: &str,
session_id: &str,
client_type: &ClientType,
state: &Arc<crate::state::AppState>,
ws_session: &mut actix_ws::Session,
) {
let msg: WsMessage = match serde_json::from_str(raw) {
Ok(m) => m,
Err(e) => {
warn!("[ws] invalid message: {} ({})", e, raw.chars().take(120).collect::<String>());
let _ = ws_session
.text(serde_json::to_string(&WsMessage::Error {
message: format!("invalid message: {}", e),
}).unwrap_or_default())
.await;
return;
}
};
match msg {
// ── From Agent ────────────────────────────────────────────────────
WsMessage::DisplayFrame { data, .. } if *client_type == ClientType::Agent => {
// Store in frame buffer for late-joiners, then broadcast.
state.push_frame(session_id, data.clone());
broadcast_to_viewers(state, session_id, &WsMessage::FrameBroadcast {
data,
content_type: "image/jpeg".into(),
}).await;
}
WsMessage::AudioFrame { data, .. } if *client_type == ClientType::Agent => {
broadcast_to_viewers(state, session_id, &WsMessage::AudioBroadcast {
data,
content_type: "audio/opus".into(),
}).await;
}
WsMessage::AgentInfo { agent_id, resolution, .. } if *client_type == ClientType::Agent => {
state.activate_session(session_id, resolution.as_deref());
// Broadcast session update to all viewers.
if let Some(session) = state.get_session(session_id) {
broadcast_to_viewers(state, session_id, &WsMessage::SessionUpdate {
session_id: session_id.to_string(),
status: session.status,
resolution: session.resolution,
}).await;
}
info!("[ws] agent {} reported for session {}", agent_id, session_id);
let _ = ws_session.text(serde_json::to_string(&WsMessage::Ack {
message: "agent registered".into(),
}).unwrap_or_default()).await;
}
WsMessage::Heartbeat if *client_type == ClientType::Agent => {
// Just a keepalive; nothing to do.
}
// ── From Viewer ───────────────────────────────────────────────────
WsMessage::HudCommand { command, params, .. } if *client_type == ClientType::Viewer => {
// Forward to the agent for this session.
forward_to_agent(state, session_id, &WsMessage::ForwardHudCommand { command, params }).await;
}
WsMessage::Resize { width, height, .. } if *client_type == ClientType::Viewer => {
forward_to_agent(state, session_id, &WsMessage::ForwardResize { width, height }).await;
}
// ── Fallback ──────────────────────────────────────────────────────
_ => {
warn!("[ws] unexpected message type from {:?} for session {}", client_type, session_id);
}
}
}
/// Broadcast a message to all **viewer** WebSocket connections for a session.
///
/// In the current implementation we use the frame buffer for display and
/// the broadcast is logged. Once the viewer registry is wired up with per-WS
/// senders, this will fan out to every connected viewer.
async fn broadcast_to_viewers(state: &Arc<crate::state::AppState>, session_id: &str, msg: &WsMessage) {
// TODO: maintain a registry of per-viewer actix_ws::Session senders in AppState.
// For now we log and rely on the frame buffer for new viewers.
log::debug!("[ws] broadcast to viewers of session {}: {:?}", session_id, msg.msg_type_debug());
}
/// Forward a message to the agent connected to a session.
async fn forward_to_agent(state: &Arc<crate::state::AppState>, session_id: &str, msg: &WsMessage) {
// TODO: maintain per-agent mpsc channel in AppState. For now just log.
log::info!(
"[ws] forward to agent in session {}: {:?}",
session_id,
msg.msg_type_debug()
);
}
/// Clean up when a client disconnects.
fn cleanup(state: &Arc<crate::state::AppState>, session_id: &str, client_type: &ClientType) {
if *client_type == ClientType::Agent {
// Find and unregister the first agent for this session.
let agent_id = state
.agents
.iter()
.find(|r| r.session_id == session_id)
.map(|r| r.agent_id.clone());
if let Some(aid) = agent_id {
state.unregister_agent(&aid);
}
}
}
/// Helper trait to get a human-readable tag for logging.
trait MsgTypeDebug {
fn msg_type_debug(&self) -> &'static str;
}
impl MsgTypeDebug for WsMessage {
fn msg_type_debug(&self) -> &'static str {
match self {
WsMessage::DisplayFrame { .. } => "display_frame",
WsMessage::AudioFrame { .. } => "audio_frame",
WsMessage::AgentInfo { .. } => "agent_info",
WsMessage::Heartbeat => "heartbeat",
WsMessage::HudCommand { .. } => "hud_command",
WsMessage::Resize { .. } => "resize",
WsMessage::FrameBroadcast { .. } => "frame_broadcast",
WsMessage::AudioBroadcast { .. } => "audio_broadcast",
WsMessage::SessionUpdate { .. } => "session_update",
WsMessage::ForwardHudCommand { .. } => "forward_hud_command",
WsMessage::ForwardResize { .. } => "forward_resize",
WsMessage::StreamControl { .. } => "stream_control",
WsMessage::Error { .. } => "error",
WsMessage::Ack { .. } => "ack",
}
}
}