server: ws/handler.rs — binary frame relay (zero-copy), text JSON for control; update api/sessions.rs for new WsOutMessage

This commit is contained in:
Butterfly Dev 2026-04-07 05:00:34 +00:00
parent 31a862b75b
commit 05cfe9e479
2 changed files with 165 additions and 242 deletions

View File

@ -2,8 +2,8 @@ use actix_web::{web, HttpResponse};
use serde::Deserialize; use serde::Deserialize;
use std::sync::Arc; use std::sync::Arc;
use crate::models::{ApiResponse, Session, WsMessage}; use crate::models::{ApiResponse, Session};
use crate::state::AppState; use crate::state::{AppState, WsOutMessage};
/// `GET /api/sessions` — list every session. /// `GET /api/sessions` — list every session.
pub async fn list_sessions(state: web::Data<Arc<AppState>>) -> HttpResponse { pub async fn list_sessions(state: web::Data<Arc<AppState>>) -> HttpResponse {
@ -64,11 +64,11 @@ pub async fn send_hud_command(
let session_id = path.into_inner(); let session_id = path.into_inner();
// Build the forward message. // Build the forward message.
let msg = WsMessage::ForwardHudCommand { let json = match serde_json::to_string(&serde_json::json!({
command: body.command.clone(), "msg_type": "forward_hud_command",
params: body.params.clone(), "command": body.command,
}; "params": body.params,
let json = match serde_json::to_string(&msg) { })) {
Ok(j) => j, Ok(j) => j,
Err(e) => { Err(e) => {
return HttpResponse::InternalServerError() return HttpResponse::InternalServerError()
@ -77,7 +77,7 @@ pub async fn send_hud_command(
}; };
// Send through the agent channel. // Send through the agent channel.
if state.send_to_agent(&session_id, &json).await { if state.send_to_agent(&session_id, WsOutMessage::Text(json)).await {
HttpResponse::Ok().json(ApiResponse::ok("command forwarded")) HttpResponse::Ok().json(ApiResponse::ok("command forwarded"))
} else { } else {
HttpResponse::Conflict() HttpResponse::Conflict()

View File

@ -6,14 +6,12 @@ use futures::StreamExt;
use log::{info, warn}; use log::{info, warn};
use tokio::sync::mpsc; use tokio::sync::mpsc;
use crate::models::{ClientType, WsMessage}; use crate::state::{WsOutMessage, MAX_BINARY_FRAME_SIZE};
/// ACTIX-WEB HTTP HANDLER /// ACTIX-WEB HTTP HANDLER
/// ///
/// Upgrades the HTTP connection to a WebSocket and spawns an async task /// Upgrades HTTP to WebSocket. Handles both binary frames (video relay) and
/// that reads frames from the client and dispatches them. /// text frames (JSON control messages).
///
/// The query parameter `client_type` must be `"viewer"` or `"agent"`.
pub async fn ws_index( pub async fn ws_index(
req: actix_web::HttpRequest, req: actix_web::HttpRequest,
body: actix_web::web::Payload, body: actix_web::web::Payload,
@ -22,19 +20,17 @@ pub async fn ws_index(
) -> Result<actix_web::HttpResponse, actix_web::Error> { ) -> Result<actix_web::HttpResponse, actix_web::Error> {
let session_id = path.into_inner(); let session_id = path.into_inner();
// Validate that the session exists.
if !state.sessions.contains_key(&session_id) { if !state.sessions.contains_key(&session_id) {
return Ok(actix_web::HttpResponse::NotFound().json( return Ok(actix_web::HttpResponse::NotFound().json(
crate::models::ApiResponse::<()>::err("session not found"), crate::models::ApiResponse::<()>::err("session not found"),
)); ));
} }
// Determine client type from query string.
let query_str = req.query_string(); let query_str = req.query_string();
let client_type = if query_str.contains("client_type=agent") { let client_type = if query_str.contains("client_type=agent") {
ClientType::Agent "agent"
} else { } else {
ClientType::Viewer "viewer"
}; };
let ip = req let ip = req
@ -43,71 +39,49 @@ pub async fn ws_index(
.unwrap_or("unknown") .unwrap_or("unknown")
.to_string(); .to_string();
// Perform the WebSocket upgrade.
let (response, session, msg_stream) = actix_ws::handle(&req, body)?; let (response, session, msg_stream) = actix_ws::handle(&req, body)?;
info!( info!("[ws] {} connected to session {} (ip={})", client_type, session_id, ip);
"[ws] {} connected to session {} (ip={})",
match client_type {
ClientType::Agent => "AGENT",
ClientType::Viewer => "VIEWER",
},
session_id,
ip
);
// ── Per-connection setup based on client type ────────────────────────── // Create mpsc channel for outgoing messages to this client.
let channel_cap = match client_type {
// Create an mpsc channel for outgoing messages to this client. "agent" => 64,
// Viewers get a larger buffer since they receive frames (high throughput). _ => 120,
// Agents get a smaller buffer since they only receive HUD commands (low frequency).
let channel_capacity = match client_type {
ClientType::Viewer => 120,
ClientType::Agent => 64,
}; };
let (tx, mut rx) = mpsc::channel::<String>(channel_capacity); let (tx, mut rx) = mpsc::channel::<WsOutMessage>(channel_cap);
// Track IDs for cleanup. // Track IDs for cleanup.
let mut agent_id_for_cleanup: Option<String> = None; let mut agent_id: Option<String> = None;
let mut viewer_id_for_cleanup: Option<String> = None; let mut viewer_id: Option<String> = None;
if client_type == ClientType::Agent { if client_type == "agent" {
let agent_id = uuid::Uuid::new_v4().to_string(); let aid = uuid::Uuid::new_v4().to_string();
let agent = crate::models::AgentConnection { let agent = crate::models::AgentConnection {
agent_id: agent_id.clone(), agent_id: aid.clone(),
session_id: session_id.clone(), session_id: session_id.clone(),
connected_at: chrono::Utc::now(), connected_at: chrono::Utc::now(),
ip_address: ip.clone(), ip_address: ip,
display_active: false, display_active: false,
audio_active: false, audio_active: false,
}; };
state.register_agent(agent); state.register_agent(agent);
state.register_agent_channel(&session_id, tx); state.register_agent_channel(&session_id, tx);
agent_id_for_cleanup = Some(agent_id); agent_id = Some(aid);
} else { } else {
let viewer_id = uuid::Uuid::new_v4().to_string(); let vid = uuid::Uuid::new_v4().to_string();
state.register_viewer(&session_id, &viewer_id, tx); state.register_viewer(&session_id, &vid, tx);
// Send the latest buffered frame to the new viewer immediately, // Send the latest buffered frame to the new viewer immediately.
// so they see something without waiting for the next frame from the agent.
if let Some(frame_data) = state.get_latest_frame(&session_id) { if let Some(frame_data) = state.get_latest_frame(&session_id) {
let msg = WsMessage::FrameBroadcast { let _ = tx.try_send(WsOutMessage::Binary(frame_data));
data: frame_data, }
content_type: "image/jpeg".into(), viewer_id = Some(vid);
};
let json = serde_json::to_string(&msg).unwrap_or_default();
// Best-effort immediate send on the channel (non-blocking).
let _ = tx.try_send(json);
} }
viewer_id_for_cleanup = Some(viewer_id); // Spawn the connection task.
}
// ── Spawn the connection task ──────────────────────────────────────────
let state_clone = state.clone(); let state_clone = state.clone();
let session_id_clone = session_id.clone(); let session_id_clone = session_id.clone();
let client_type_clone = client_type.clone(); let client_type_clone = client_type.to_string();
let timeout = Duration::from_secs(state.idle_timeout_secs); let timeout = Duration::from_secs(state.idle_timeout_secs);
actix_web::rt::spawn(async move { actix_web::rt::spawn(async move {
@ -116,253 +90,202 @@ pub async fn ws_index(
loop { loop {
tokio::select! { tokio::select! {
// ── Incoming WebSocket message from the client ───────────── // ── Incoming WebSocket message ──────────────────────────
ws_msg = msg_stream.next() => { ws_msg = msg_stream.next() => {
last_activity = Instant::now(); last_activity = Instant::now();
timeout_sleep = tokio::time::sleep(timeout); timeout_sleep = tokio::time::sleep(timeout);
match ws_msg { match ws_msg {
Some(Ok(Message::Text(text))) => { Some(Ok(Message::Text(text))) => {
handle_text_message( handle_text_message(&text, &session_id_clone, &client_type_clone, &state_clone, &session).await;
&text, }
&session_id_clone, Some(Ok(Message::Binary(data))) => {
&client_type_clone, handle_binary_message(&data, &session_id_clone, &client_type_clone, &state_clone).await;
&state_clone,
&session,
)
.await;
} }
Some(Ok(Message::Ping(bytes))) => { Some(Ok(Message::Ping(bytes))) => {
let _ = session.pong(&bytes).await; let _ = session.pong(&bytes).await;
} }
Some(Ok(Message::Close(reason))) => { Some(Ok(Message::Close(reason))) => {
info!( info!("[ws] {} disconnected from session {} ({:?})", client_type_clone, session_id_clone, reason);
"[ws] {:?} disconnected from session {} (close: {:?})",
client_type_clone, session_id_clone, reason
);
break; break;
} }
Some(Ok(Message::Binary(_))) => { Some(Ok(Message::Frame(_))) => {}
// We only use text (JSON) messages. Ignore binary.
}
Some(Err(e)) => { Some(Err(e)) => {
warn!("[ws] read error for {:?} on session {}: {}", client_type_clone, session_id_clone, e); warn!("[ws] read error {} on session {}: {}", client_type_clone, session_id_clone, e);
break;
}
None => {
// Stream ended (client disconnected).
break; break;
} }
None => { break; }
} }
} }
// ── Outgoing message from the broadcast/forward channel ──── // ── Outgoing message from broadcast/forward channel ────
out_msg = rx.recv() => { out_msg = rx.recv() => {
last_activity = Instant::now(); last_activity = Instant::now();
timeout_sleep = tokio::time::sleep(timeout); timeout_sleep = tokio::time::sleep(timeout);
match out_msg { match out_msg {
Some(text) => { Some(WsOutMessage::Binary(data)) => {
if session.binary(data).await.is_err() {
break;
}
}
Some(WsOutMessage::Text(text)) => {
if session.text(&text).await.is_err() { if session.text(&text).await.is_err() {
warn!("[ws] write failed for {:?} on session {}", client_type_clone, session_id_clone);
break; break;
} }
} }
None => { None => { break; }
// Channel closed (sender dropped during cleanup).
break;
}
} }
} }
// ── Idle timeout ─────────────────────────────────────────── // ── Idle timeout ───────────────────────────────────────
_ = &mut timeout_sleep => { _ = &mut timeout_sleep => {
warn!( warn!("[ws] {} timed out on session {}", client_type_clone, session_id_clone);
"[ws] {:?} timed out on session {} ({}s idle)",
client_type_clone, session_id_clone, timeout.as_secs()
);
break; break;
} }
} }
} }
// ── Cleanup ──────────────────────────────────────────────────────── // Cleanup.
cleanup_connection( match client_type_clone.as_str() {
&state_clone, "agent" => {
&session_id_clone, if let Some(aid) = agent_id {
&client_type_clone, state_clone.unregister_agent(&aid);
&agent_id_for_cleanup, }
&viewer_id_for_cleanup, state_clone.unregister_agent_channel(&session_id_clone);
); }
_ => {
// Best-effort close the WebSocket. if let Some(vid) = viewer_id {
state_clone.unregister_viewer(&session_id_clone, &vid);
}
}
}
let _ = session.close(None).await; let _ = session.close(None).await;
}); });
Ok(response) Ok(response)
} }
// ── Internal helpers ───────────────────────────────────────────────────────── // ── Binary frame handler (video relay) ────────────────────────────────────────
/// Parse and dispatch an incoming text (JSON) WebSocket message. /// Handle a binary WebSocket frame.
///
/// From agent: Raw video frame → push to buffer + broadcast to all viewers.
/// From viewer: Ignored (viewers shouldn't send binary frames).
async fn handle_binary_message(
data: &[u8],
session_id: &str,
client_type: &str,
state: &Arc<crate::state::AppState>,
) {
if client_type != "agent" {
return;
}
// Reject oversized frames.
if data.len() > MAX_BINARY_FRAME_SIZE {
warn!("[ws] oversized binary frame ({} bytes) from agent in session {}", data.len(), session_id);
return;
}
// Store in ring buffer (for late-joining viewers).
state.push_frame(session_id, data.to_vec());
// Broadcast to all viewers (zero-copy-ish: data is cloned per viewer).
state.broadcast_binary_frame(session_id, data.to_vec()).await;
}
// ── Text message handler (JSON control) ──────────────────────────────────────
/// Handle a text (JSON) WebSocket message.
async fn handle_text_message( async fn handle_text_message(
raw: &str, raw: &str,
session_id: &str, session_id: &str,
client_type: &ClientType, client_type: &str,
state: &Arc<crate::state::AppState>, state: &Arc<crate::state::AppState>,
ws_session: &Session, ws_session: &Session,
) { ) {
let msg: WsMessage = match serde_json::from_str(raw) { // Try to parse as JSON control message.
Ok(m) => m, let v: serde_json::Value = match serde_json::from_str(raw) {
Ok(v) => v,
Err(e) => { Err(e) => {
warn!( warn!("[ws] invalid JSON from {}: {} ({})", client_type, e, &raw[..raw.len().min(100)]);
"[ws] invalid message: {} ({})", let _ = ws_session.text(serde_json::to_string(&serde_json::json!({
e, "msg_type": "error",
raw.chars().take(120).collect::<String>() "message": format!("invalid JSON: {}", e)
); })).unwrap_or_default()).await;
let _ = ws_session
.text(
serde_json::to_string(&WsMessage::Error {
message: format!("invalid message: {}", e),
})
.unwrap_or_default(),
)
.await;
return; return;
} }
}; };
match msg { let msg_type = v["msg_type"].as_str().unwrap_or("");
// ── From Agent ────────────────────────────────────────────────────
WsMessage::DisplayFrame { data, .. } if *client_type == ClientType::Agent => {
state.push_frame(session_id, data.clone());
let broadcast = WsMessage::FrameBroadcast { match (client_type, msg_type) {
data, // ── Agent messages ─────────────────────────────────────────────
content_type: "image/jpeg".into(), ("agent", "agent_info") => {
}; let resolution = v["resolution"].as_str();
let json = serde_json::to_string(&broadcast).unwrap_or_default(); state.activate_session(session_id, resolution);
state.broadcast_to_viewers(session_id, &json).await;
// Notify viewers.
let update = serde_json::json!({
"msg_type": "session_update",
"session_id": session_id,
"status": "active",
"resolution": resolution,
});
state.broadcast_text(session_id, serde_json::to_string(&update).unwrap_or_default()).await;
let _ = ws_session.text(serde_json::to_string(&serde_json::json!({
"msg_type": "ack",
"message": "agent registered"
})).unwrap_or_default()).await;
info!("[ws] agent registered for session {} (encoder: {})", session_id, v["encoder"].as_str().unwrap_or("unknown"));
}
("agent", "heartbeat") => {
// Keepalive — timer reset at outer level.
} }
WsMessage::AudioFrame { data, .. } if *client_type == ClientType::Agent => { // ── Viewer messages ────────────────────────────────────────────
let broadcast = WsMessage::AudioBroadcast { ("viewer", "hud_command") => {
data, let command = v["command"].as_str().unwrap_or("");
content_type: "audio/opus".into(), let params = &v["params"];
};
let json = serde_json::to_string(&broadcast).unwrap_or_default(); // Forward to agent as a text message.
state.broadcast_to_viewers(session_id, &json).await; let forward = serde_json::json!({
"msg_type": "forward_hud_command",
"command": command,
"params": params,
});
if !state.send_to_agent(session_id, WsOutMessage::Text(
serde_json::to_string(&forward).unwrap_or_default()
)).await {
let _ = ws_session.text(serde_json::to_string(&serde_json::json!({
"msg_type": "error",
"message": "no agent connected"
})).unwrap_or_default()).await;
}
}
("viewer", "resize") => {
let forward = serde_json::json!({
"msg_type": "forward_resize",
"width": v["width"],
"height": v["height"],
});
state.send_to_agent(session_id, WsOutMessage::Text(
serde_json::to_string(&forward).unwrap_or_default()
)).await;
}
("viewer", "heartbeat") => {
let _ = ws_session.text(serde_json::to_string(&serde_json::json!({
"msg_type": "ack",
"message": "heartbeat"
})).unwrap_or_default()).await;
} }
WsMessage::AgentInfo { // ── Unknown ────────────────────────────────────────────────────
agent_id,
resolution,
..
} if *client_type == ClientType::Agent => {
state.activate_session(session_id, resolution.as_deref());
if let Some(session) = state.get_session(session_id) {
let update = WsMessage::SessionUpdate {
session_id: session_id.to_string(),
status: session.status,
resolution: session.resolution,
};
let json = serde_json::to_string(&update).unwrap_or_default();
state.broadcast_to_viewers(session_id, &json).await;
}
info!(
"[ws] agent {} reported for session {}",
agent_id, session_id
);
let _ = ws_session
.text(
serde_json::to_string(&WsMessage::Ack {
message: "agent registered".into(),
})
.unwrap_or_default(),
)
.await;
}
WsMessage::Heartbeat if *client_type == ClientType::Agent => {
// Keepalive — nothing to do, the timeout is reset by receiving any message.
}
// ── From Viewer ───────────────────────────────────────────────────
WsMessage::HudCommand {
command, params, ..
} if *client_type == ClientType::Viewer => {
let forward = WsMessage::ForwardHudCommand { command, params };
let json = serde_json::to_string(&forward).unwrap_or_default();
if !state.send_to_agent(session_id, &json).await {
let _ = ws_session
.text(
serde_json::to_string(&WsMessage::Error {
message: "no agent connected for this session".into(),
})
.unwrap_or_default(),
)
.await;
}
}
WsMessage::Resize {
width, height, ..
} if *client_type == ClientType::Viewer => {
let forward = WsMessage::ForwardResize { width, height };
let json = serde_json::to_string(&forward).unwrap_or_default();
if !state.send_to_agent(session_id, &json).await {
let _ = ws_session
.text(
serde_json::to_string(&WsMessage::Error {
message: "no agent connected for this session".into(),
})
.unwrap_or_default(),
)
.await;
}
}
// ── Heartbeat from Viewer ─────────────────────────────────────────
WsMessage::Heartbeat if *client_type == ClientType::Viewer => {
let _ = ws_session
.text(
serde_json::to_string(&WsMessage::Ack {
message: "heartbeat".into(),
})
.unwrap_or_default(),
)
.await;
}
// ── Fallback ──────────────────────────────────────────────────────
_ => { _ => {
warn!( warn!("[ws] unexpected msg_type '{}' from {} for session {}", msg_type, client_type, session_id);
"[ws] unexpected message type from {:?} for session {}",
client_type, session_id
);
}
}
}
/// Clean up when a client disconnects.
fn cleanup_connection(
state: &Arc<crate::state::AppState>,
session_id: &str,
client_type: &ClientType,
agent_id: &Option<String>,
viewer_id: &Option<String>,
) {
match client_type {
ClientType::Agent => {
if let Some(aid) = agent_id {
state.unregister_agent(aid);
}
state.unregister_agent_channel(session_id);
}
ClientType::Viewer => {
if let Some(vid) = viewer_id {
state.unregister_viewer(session_id, vid);
}
} }
} }
} }