projects/server/src/ws/handler.rs

274 lines
9.9 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use std::sync::Arc;
use std::time::{Duration, Instant};
use actix_ws::{Message, Session};
use futures::StreamExt;
use log::{info, warn};
use crate::models::{ClientType, WsMessage};
/// Actix-web HTTP handler that upgrades the connection to a WebSocket and
/// spawns an async task that reads frames from the client and dispatches them.
///
/// The path parameter is the session id. If it is not present in
/// `state.sessions`, the handler responds `404` with an `ApiResponse` error body.
///
/// The query parameter `client_type` selects the role: exactly `agent` makes
/// the connection an agent. Any other value, or no `client_type` at all, makes
/// it a viewer.
///
/// Agents are registered in shared state under a freshly generated UUID. When
/// the reader task ends, that same id is unregistered.
///
/// The reader task ends when any of the following happens:
/// - the client sends a Close frame;
/// - the frame stream ends or yields a protocol error;
/// - no Text/Ping frame has been received for `state.idle_timeout_secs` seconds.
///
/// In every case the server side of the WebSocket is then closed.
///
/// # Errors
/// Returns the error from `actix_ws::handle` when the request is not a valid
/// WebSocket upgrade.
pub async fn ws_index(
    req: actix_web::HttpRequest,
    body: actix_web::web::Payload,
    state: actix_web::web::Data<Arc<crate::state::AppState>>,
    path: actix_web::web::Path<String>,
) -> Result<actix_web::HttpResponse, actix_web::Error> {
    let session_id = path.into_inner();

    // Validate that the session exists.
    if !state.sessions.contains_key(&session_id) {
        return Ok(actix_web::HttpResponse::NotFound().json(
            crate::models::ApiResponse::<()>::err("session not found"),
        ));
    }

    // Match the `client_type` key and value exactly. A substring search would
    // also accept e.g. `client_type=agentx` or `foo_client_type=agent`.
    let client_type = match query_param(req.query_string(), "client_type") {
        Some("agent") => ClientType::Agent,
        _ => ClientType::Viewer,
    };

    let ip = req
        .connection_info()
        .realip_remote_addr()
        .unwrap_or("unknown")
        .to_string();

    // Perform the WebSocket upgrade.
    let (response, mut session, mut msg_stream) = actix_ws::handle(&req, body)?;

    info!(
        "[ws] {} connected to session {} (ip={})",
        match client_type {
            ClientType::Agent => "AGENT",
            ClientType::Viewer => "VIEWER",
        },
        session_id,
        ip
    );

    // Register agents in shared state and remember the generated id, so that
    // teardown removes exactly this connection's entry. Another agent attached
    // to the same session is left alone.
    let agent_id = if client_type == ClientType::Agent {
        let agent_id = uuid::Uuid::new_v4().to_string();
        state.register_agent(crate::models::AgentConnection {
            agent_id: agent_id.clone(),
            session_id: session_id.clone(),
            connected_at: chrono::Utc::now(),
            ip_address: ip,
            display_active: false,
            audio_active: false,
        });
        Some(agent_id)
    } else {
        None
    };

    // The reader task takes ownership of everything it needs. None of these
    // values are used by the handler after this point.
    actix_web::rt::spawn(async move {
        let idle_timeout = Duration::from_secs(state.idle_timeout_secs);
        let mut last_heartbeat = Instant::now();
        // Reason to send in our own Close frame (echoes the client's, if any).
        let mut close_reason = None;

        loop {
            // Bound each wait by the remaining idle budget. Without this, a
            // client that goes silent would block `next()` forever and the idle
            // timeout would never be enforced.
            let remaining = idle_timeout.saturating_sub(last_heartbeat.elapsed());
            let next = match actix_web::rt::time::timeout(remaining, msg_stream.next()).await {
                Ok(next) => next,
                Err(_elapsed) => {
                    warn!("[ws] {:?} timed out on session {}", client_type, session_id);
                    break;
                }
            };

            match next {
                Some(Ok(Message::Text(text))) => {
                    last_heartbeat = Instant::now();
                    handle_text_message(&text, &session_id, &client_type, &state, &mut session)
                        .await;
                }
                Some(Ok(Message::Ping(bytes))) => {
                    last_heartbeat = Instant::now();
                    let _ = session.pong(&bytes).await;
                }
                Some(Ok(Message::Close(reason))) => {
                    info!(
                        "[ws] {:?} disconnected from session {} (close: {:?})",
                        client_type, session_id, reason
                    );
                    close_reason = reason;
                    break;
                }
                // All other frame kinds are ignored and do not count as activity.
                Some(Ok(_)) => {}
                Some(Err(err)) => {
                    warn!(
                        "[ws] {:?} protocol error on session {}: {}",
                        client_type, session_id, err
                    );
                    break;
                }
                None => {
                    info!(
                        "[ws] {:?} stream ended for session {}",
                        client_type, session_id
                    );
                    break;
                }
            }

            // Frames that do not refresh the heartbeat must not keep the
            // connection alive past the idle timeout.
            if last_heartbeat.elapsed() > idle_timeout {
                warn!("[ws] {:?} timed out on session {}", client_type, session_id);
                break;
            }
        }

        // Agents are removed by their own id. Other clients go through the
        // generic `cleanup` hook.
        if let Some(id) = &agent_id {
            state.unregister_agent(id);
        } else {
            cleanup(&state, &session_id, &client_type);
        }
        // Always close our side; errors (e.g. already closed) are irrelevant here.
        let _ = session.close(close_reason).await;
    });

    Ok(response)
}

/// Return the raw value of the first `key=value` pair in `query` whose key is
/// exactly `key`, or `None` if there is none.
///
/// Pairs without `=` are skipped. Values are not percent-decoded.
fn query_param<'a>(query: &'a str, key: &str) -> Option<&'a str> {
    query
        .split('&')
        .filter_map(|pair| pair.split_once('='))
        .find_map(|(k, v)| (k == key).then_some(v))
}
// ── Internal helpers ─────────────────────────────────────────────────────────
/// Parse and dispatch an incoming text (JSON) WebSocket message.
async fn handle_text_message(
raw: &str,
session_id: &str,
client_type: &ClientType,
state: &Arc<crate::state::AppState>,
ws_session: &mut Session,
) {
let msg: WsMessage = match serde_json::from_str(raw) {
Ok(m) => m,
Err(e) => {
warn!("[ws] invalid message: {} ({})", e, raw.chars().take(120).collect::<String>());
let _ = ws_session
.text(serde_json::to_string(&WsMessage::Error {
message: format!("invalid message: {}", e),
}).unwrap_or_default())
.await;
return;
}
};
match msg {
// ── From Agent ────────────────────────────────────────────────────
WsMessage::DisplayFrame { data, .. } if *client_type == ClientType::Agent => {
state.push_frame(session_id, data.clone());
broadcast_to_viewers(state, session_id, &WsMessage::FrameBroadcast {
data,
content_type: "image/jpeg".into(),
}).await;
}
WsMessage::AudioFrame { data, .. } if *client_type == ClientType::Agent => {
broadcast_to_viewers(state, session_id, &WsMessage::AudioBroadcast {
data,
content_type: "audio/opus".into(),
}).await;
}
WsMessage::AgentInfo { agent_id, resolution, .. } if *client_type == ClientType::Agent => {
state.activate_session(session_id, resolution.as_deref());
if let Some(session) = state.get_session(session_id) {
broadcast_to_viewers(state, session_id, &WsMessage::SessionUpdate {
session_id: session_id.to_string(),
status: session.status,
resolution: session.resolution,
}).await;
}
info!("[ws] agent {} reported for session {}", agent_id, session_id);
let _ = ws_session.text(serde_json::to_string(&WsMessage::Ack {
message: "agent registered".into(),
}).unwrap_or_default()).await;
}
WsMessage::Heartbeat if *client_type == ClientType::Agent => {
// Keepalive — nothing to do.
}
// ── From Viewer ───────────────────────────────────────────────────
WsMessage::HudCommand { command, params, .. } if *client_type == ClientType::Viewer => {
forward_to_agent(state, session_id, &WsMessage::ForwardHudCommand { command, params }).await;
}
WsMessage::Resize { width, height, .. } if *client_type == ClientType::Viewer => {
forward_to_agent(state, session_id, &WsMessage::ForwardResize { width, height }).await;
}
// ── Fallback ──────────────────────────────────────────────────────
_ => {
warn!("[ws] unexpected message type from {:?} for session {}", client_type, session_id);
}
}
}
/// Deliver `msg` to every **viewer** WebSocket attached to `session_id`.
///
/// Currently a stub: nothing is sent and `_state` is unused. The call is only
/// recorded at `debug` level together with the message's type tag.
///
/// TODO: maintain a registry of per-viewer Session senders in AppState.
/// For now we log and rely on the frame buffer for new viewers.
async fn broadcast_to_viewers(
    _state: &Arc<crate::state::AppState>,
    session_id: &str,
    msg: &WsMessage,
) {
    let kind = msg.msg_type_debug();
    log::debug!(
        "[ws] broadcast to viewers of session {}: {:?}",
        session_id,
        kind
    );
}
/// Hand `msg` to the agent connected to `session_id`.
///
/// Currently a stub: nothing reaches the agent and `_state` is unused. The
/// call is only recorded at `info` level together with the message's type tag.
///
/// TODO: maintain per-agent mpsc channel in AppState.
async fn forward_to_agent(
    _state: &Arc<crate::state::AppState>,
    session_id: &str,
    msg: &WsMessage,
) {
    let kind = msg.msg_type_debug();
    log::info!("[ws] forward to agent in session {}: {:?}", session_id, kind);
}
/// Release shared state held by a client that has disconnected.
///
/// Viewers hold nothing here, so they return immediately. For an agent, the
/// first entry in `state.agents` whose `session_id` matches is unregistered.
///
/// NOTE(review): the lookup is by session, not by agent id. If more than one
/// agent is attached to the same session, the entry removed may not belong to
/// the connection that is closing.
fn cleanup(state: &Arc<crate::state::AppState>, session_id: &str, client_type: &ClientType) {
    if *client_type != ClientType::Agent {
        return;
    }
    // Resolve the id in a statement of its own so the map iterator (and any
    // guard it holds) is dropped before `unregister_agent` mutates the map.
    let matching_agent = state
        .agents
        .iter()
        .find(|entry| entry.session_id == session_id)
        .map(|entry| entry.agent_id.clone());
    if let Some(agent_id) = matching_agent {
        state.unregister_agent(&agent_id);
    }
}
/// Short snake_case tag naming a `WsMessage` variant, for use in log lines.
trait MsgTypeDebug {
    /// Returns the tag for `self`'s variant. Payload fields are never inspected.
    fn msg_type_debug(&self) -> &'static str;
}

impl MsgTypeDebug for WsMessage {
    fn msg_type_debug(&self) -> &'static str {
        match self {
            // Variants `handle_text_message` accepts from clients.
            WsMessage::DisplayFrame { .. } => "display_frame",
            WsMessage::AudioFrame { .. } => "audio_frame",
            WsMessage::AgentInfo { .. } => "agent_info",
            WsMessage::Heartbeat => "heartbeat",
            WsMessage::HudCommand { .. } => "hud_command",
            WsMessage::Resize { .. } => "resize",
            // Variants this file builds for viewers.
            WsMessage::FrameBroadcast { .. } => "frame_broadcast",
            WsMessage::AudioBroadcast { .. } => "audio_broadcast",
            WsMessage::SessionUpdate { .. } => "session_update",
            // Variants this file builds for the agent.
            WsMessage::ForwardHudCommand { .. } => "forward_hud_command",
            WsMessage::ForwardResize { .. } => "forward_resize",
            // Replies sent back to the sender.
            WsMessage::Error { .. } => "error",
            WsMessage::Ack { .. } => "ack",
            // Not produced or consumed in this file.
            WsMessage::StreamControl { .. } => "stream_control",
        }
    }
}