ws/handler: implement real bidirectional relay — mpsc channels, viewer catch-up, select! loop

This commit is contained in:
Butterfly Dev 2026-04-07 03:57:45 +00:00
parent 29eda76675
commit 2344060d73

View File

@ -4,6 +4,7 @@ use std::time::{Duration, Instant};
use actix_ws::{Message, Session};
use futures::StreamExt;
use log::{info, warn};
use tokio::sync::mpsc;
use crate::models::{ClientType, WsMessage};
@ -43,9 +44,8 @@ pub async fn ws_index(
.to_string();
// Perform the WebSocket upgrade.
let (response, mut session, mut msg_stream) = actix_ws::handle(&req, body)?;
let (response, session, msg_stream) = actix_ws::handle(&req, body)?;
// Log the connection.
info!(
"[ws] {} connected to session {} (ip={})",
match client_type {
@ -56,7 +56,21 @@ pub async fn ws_index(
ip
);
// Register agent in shared state.
// ── Per-connection setup based on client type ──────────────────────────
// Create an mpsc channel for outgoing messages to this client.
// Viewers get a larger buffer since they receive frames (high throughput).
// Agents get a smaller buffer since they only receive HUD commands (low frequency).
let channel_capacity = match client_type {
ClientType::Viewer => 120,
ClientType::Agent => 64,
};
let (tx, mut rx) = mpsc::channel::<String>(channel_capacity);
// Track IDs for cleanup.
let mut agent_id_for_cleanup: Option<String> = None;
let mut viewer_id_for_cleanup: Option<String> = None;
if client_type == ClientType::Agent {
let agent_id = uuid::Uuid::new_v4().to_string();
let agent = crate::models::AgentConnection {
@ -68,63 +82,121 @@ pub async fn ws_index(
audio_active: false,
};
state.register_agent(agent);
state.register_agent_channel(&session_id, tx);
agent_id_for_cleanup = Some(agent_id);
} else {
let viewer_id = uuid::Uuid::new_v4().to_string();
state.register_viewer(&session_id, &viewer_id, tx);
// Send the latest buffered frame to the new viewer immediately,
// so they see something without waiting for the next frame from the agent.
if let Some(frame_data) = state.get_latest_frame(&session_id) {
let msg = WsMessage::FrameBroadcast {
data: frame_data,
content_type: "image/jpeg".into(),
};
let json = serde_json::to_string(&msg).unwrap_or_default();
// Best-effort immediate send on the channel (non-blocking).
let _ = tx.try_send(json);
}
viewer_id_for_cleanup = Some(viewer_id);
}
// Clone references for the spawned task.
// ── Spawn the connection task ──────────────────────────────────────────
let state_clone = state.clone();
let session_id_clone = session_id.clone();
let client_type_clone = client_type.clone();
let timeout = Duration::from_secs(state.idle_timeout_secs);
// Spawn the reader task receives messages FROM the client.
actix_web::rt::spawn(async move {
let mut last_heartbeat = Instant::now();
let timeout = Duration::from_secs(state_clone.idle_timeout_secs);
let mut last_activity = Instant::now();
let mut timeout_sleep = tokio::time::sleep(timeout);
while let Some(Ok(msg)) = msg_stream.next().await {
match msg {
Message::Text(text) => {
last_heartbeat = Instant::now();
handle_text_message(
&text,
&session_id_clone,
&client_type_clone,
&state_clone,
&mut session,
)
.await;
loop {
tokio::select! {
// ── Incoming WebSocket message from the client ─────────────
ws_msg = msg_stream.next() => {
last_activity = Instant::now();
timeout_sleep = tokio::time::sleep(timeout);
match ws_msg {
Some(Ok(Message::Text(text))) => {
handle_text_message(
&text,
&session_id_clone,
&client_type_clone,
&state_clone,
&session,
)
.await;
}
Some(Ok(Message::Ping(bytes))) => {
let _ = session.pong(&bytes).await;
}
Some(Ok(Message::Close(reason))) => {
info!(
"[ws] {:?} disconnected from session {} (close: {:?})",
client_type_clone, session_id_clone, reason
);
break;
}
Some(Ok(Message::Binary(_))) => {
// We only use text (JSON) messages. Ignore binary.
}
Some(Err(e)) => {
warn!("[ws] read error for {:?} on session {}: {}", client_type_clone, session_id_clone, e);
break;
}
None => {
// Stream ended (client disconnected).
break;
}
}
}
Message::Ping(bytes) => {
last_heartbeat = Instant::now();
let _ = session.pong(&bytes).await;
// ── Outgoing message from the broadcast/forward channel ────
out_msg = rx.recv() => {
last_activity = Instant::now();
timeout_sleep = tokio::time::sleep(timeout);
match out_msg {
Some(text) => {
if session.text(&text).await.is_err() {
warn!("[ws] write failed for {:?} on session {}", client_type_clone, session_id_clone);
break;
}
}
None => {
// Channel closed (sender dropped during cleanup).
break;
}
}
}
Message::Close(reason) => {
info!(
"[ws] {:?} disconnected from session {} (close: {:?})",
client_type_clone, session_id_clone, reason
// ── Idle timeout ───────────────────────────────────────────
_ = &mut timeout_sleep => {
warn!(
"[ws] {:?} timed out on session {} ({}s idle)",
client_type_clone, session_id_clone, timeout.as_secs()
);
cleanup(&state_clone, &session_id_clone, &client_type_clone);
return;
break;
}
_ => {}
}
// Check for idle timeout.
if last_heartbeat.elapsed() > timeout {
warn!(
"[ws] {:?} timed out on session {}",
client_type_clone, session_id_clone
);
cleanup(&state_clone, &session_id_clone, &client_type_clone);
let _ = session.close(None).await;
return;
}
}
info!(
"[ws] {:?} stream ended for session {}",
client_type_clone, session_id_clone
// ── Cleanup ────────────────────────────────────────────────────────
cleanup_connection(
&state_clone,
&session_id_clone,
&client_type_clone,
&agent_id_for_cleanup,
&viewer_id_for_cleanup,
);
cleanup(&state_clone, &session_id_clone, &client_type_clone);
// Best-effort close the WebSocket.
let _ = session.close(None).await;
});
Ok(response)
@ -138,16 +210,23 @@ async fn handle_text_message(
session_id: &str,
client_type: &ClientType,
state: &Arc<crate::state::AppState>,
ws_session: &mut Session,
ws_session: &Session,
) {
let msg: WsMessage = match serde_json::from_str(raw) {
Ok(m) => m,
Err(e) => {
warn!("[ws] invalid message: {} ({})", e, raw.chars().take(120).collect::<String>());
warn!(
"[ws] invalid message: {} ({})",
e,
raw.chars().take(120).collect::<String>()
);
let _ = ws_session
.text(serde_json::to_string(&WsMessage::Error {
message: format!("invalid message: {}", e),
}).unwrap_or_default())
.text(
serde_json::to_string(&WsMessage::Error {
message: format!("invalid message: {}", e),
})
.unwrap_or_default(),
)
.await;
return;
}
@ -157,117 +236,133 @@ async fn handle_text_message(
// ── From Agent ────────────────────────────────────────────────────
WsMessage::DisplayFrame { data, .. } if *client_type == ClientType::Agent => {
state.push_frame(session_id, data.clone());
broadcast_to_viewers(state, session_id, &WsMessage::FrameBroadcast {
let broadcast = WsMessage::FrameBroadcast {
data,
content_type: "image/jpeg".into(),
}).await;
};
let json = serde_json::to_string(&broadcast).unwrap_or_default();
state.broadcast_to_viewers(session_id, &json).await;
}
WsMessage::AudioFrame { data, .. } if *client_type == ClientType::Agent => {
broadcast_to_viewers(state, session_id, &WsMessage::AudioBroadcast {
let broadcast = WsMessage::AudioBroadcast {
data,
content_type: "audio/opus".into(),
}).await;
};
let json = serde_json::to_string(&broadcast).unwrap_or_default();
state.broadcast_to_viewers(session_id, &json).await;
}
WsMessage::AgentInfo { agent_id, resolution, .. } if *client_type == ClientType::Agent => {
WsMessage::AgentInfo {
agent_id,
resolution,
..
} if *client_type == ClientType::Agent => {
state.activate_session(session_id, resolution.as_deref());
if let Some(session) = state.get_session(session_id) {
broadcast_to_viewers(state, session_id, &WsMessage::SessionUpdate {
let update = WsMessage::SessionUpdate {
session_id: session_id.to_string(),
status: session.status,
resolution: session.resolution,
}).await;
};
let json = serde_json::to_string(&update).unwrap_or_default();
state.broadcast_to_viewers(session_id, &json).await;
}
info!("[ws] agent {} reported for session {}", agent_id, session_id);
let _ = ws_session.text(serde_json::to_string(&WsMessage::Ack {
message: "agent registered".into(),
}).unwrap_or_default()).await;
info!(
"[ws] agent {} reported for session {}",
agent_id, session_id
);
let _ = ws_session
.text(
serde_json::to_string(&WsMessage::Ack {
message: "agent registered".into(),
})
.unwrap_or_default(),
)
.await;
}
WsMessage::Heartbeat if *client_type == ClientType::Agent => {
// Keepalive — nothing to do.
// Keepalive — nothing to do, the timeout is reset by receiving any message.
}
// ── From Viewer ───────────────────────────────────────────────────
WsMessage::HudCommand { command, params, .. } if *client_type == ClientType::Viewer => {
forward_to_agent(state, session_id, &WsMessage::ForwardHudCommand { command, params }).await;
WsMessage::HudCommand {
command, params, ..
} if *client_type == ClientType::Viewer => {
let forward = WsMessage::ForwardHudCommand { command, params };
let json = serde_json::to_string(&forward).unwrap_or_default();
if !state.send_to_agent(session_id, &json).await {
let _ = ws_session
.text(
serde_json::to_string(&WsMessage::Error {
message: "no agent connected for this session".into(),
})
.unwrap_or_default(),
)
.await;
}
}
WsMessage::Resize { width, height, .. } if *client_type == ClientType::Viewer => {
forward_to_agent(state, session_id, &WsMessage::ForwardResize { width, height }).await;
WsMessage::Resize {
width, height, ..
} if *client_type == ClientType::Viewer => {
let forward = WsMessage::ForwardResize { width, height };
let json = serde_json::to_string(&forward).unwrap_or_default();
if !state.send_to_agent(session_id, &json).await {
let _ = ws_session
.text(
serde_json::to_string(&WsMessage::Error {
message: "no agent connected for this session".into(),
})
.unwrap_or_default(),
)
.await;
}
}
// ── Heartbeat from Viewer ─────────────────────────────────────────
WsMessage::Heartbeat if *client_type == ClientType::Viewer => {
let _ = ws_session
.text(
serde_json::to_string(&WsMessage::Ack {
message: "heartbeat".into(),
})
.unwrap_or_default(),
)
.await;
}
// ── Fallback ──────────────────────────────────────────────────────
_ => {
warn!("[ws] unexpected message type from {:?} for session {}", client_type, session_id);
warn!(
"[ws] unexpected message type from {:?} for session {}",
client_type, session_id
);
}
}
}
/// Broadcast a message to all **viewer** WebSocket connections for a session.
///
/// TODO: maintain a registry of per-viewer Session senders in AppState.
/// For now we log and rely on the frame buffer for new viewers.
async fn broadcast_to_viewers(
_state: &Arc<crate::state::AppState>,
session_id: &str,
msg: &WsMessage,
) {
log::debug!("[ws] broadcast to viewers of session {}: {:?}", session_id, msg.msg_type_debug());
}
/// Forward a message to the agent connected to a session.
///
/// TODO: maintain per-agent mpsc channel in AppState.
async fn forward_to_agent(
_state: &Arc<crate::state::AppState>,
session_id: &str,
msg: &WsMessage,
) {
log::info!(
"[ws] forward to agent in session {}: {:?}",
session_id,
msg.msg_type_debug()
);
}
/// Clean up when a client disconnects.
fn cleanup(state: &Arc<crate::state::AppState>, session_id: &str, client_type: &ClientType) {
if *client_type == ClientType::Agent {
let agent_id = state
.agents
.iter()
.find(|r| r.session_id == session_id)
.map(|r| r.agent_id.clone());
if let Some(aid) = agent_id {
state.unregister_agent(&aid);
}
}
}
/// Helper trait to get a human-readable tag for logging.
trait MsgTypeDebug {
fn msg_type_debug(&self) -> &'static str;
}
impl MsgTypeDebug for WsMessage {
fn msg_type_debug(&self) -> &'static str {
match self {
WsMessage::DisplayFrame { .. } => "display_frame",
WsMessage::AudioFrame { .. } => "audio_frame",
WsMessage::AgentInfo { .. } => "agent_info",
WsMessage::Heartbeat => "heartbeat",
WsMessage::HudCommand { .. } => "hud_command",
WsMessage::Resize { .. } => "resize",
WsMessage::FrameBroadcast { .. } => "frame_broadcast",
WsMessage::AudioBroadcast { .. } => "audio_broadcast",
WsMessage::SessionUpdate { .. } => "session_update",
WsMessage::ForwardHudCommand { .. } => "forward_hud_command",
WsMessage::ForwardResize { .. } => "forward_resize",
WsMessage::StreamControl { .. } => "stream_control",
WsMessage::Error { .. } => "error",
WsMessage::Ack { .. } => "ack",
fn cleanup_connection(
state: &Arc<crate::state::AppState>,
session_id: &str,
client_type: &ClientType,
agent_id: &Option<String>,
viewer_id: &Option<String>,
) {
match client_type {
ClientType::Agent => {
if let Some(aid) = agent_id {
state.unregister_agent(aid);
}
state.unregister_agent_channel(session_id);
}
ClientType::Viewer => {
if let Some(vid) = viewer_id {
state.unregister_viewer(session_id, vid);
}
}
}
}