diff --git a/server/src/state.rs b/server/src/state.rs index f3f9b76..57d2d8b 100644 --- a/server/src/state.rs +++ b/server/src/state.rs @@ -2,10 +2,20 @@ use std::sync::Arc; use std::time::Instant; use dashmap::DashMap; -use log::info; +use log::{info, warn}; +use tokio::sync::mpsc; use crate::models::{AgentConnection, Session, SessionStatus}; +/// A registered viewer connection with its message channel. +#[derive(Debug)] +pub struct ViewerEntry { + pub viewer_id: String, + /// Channel to send outgoing messages to this viewer. + /// Wrapped in parking_lot::Mutex so it satisfies DashMap's `Send + Sync` bound. + pub sender: parking_lot::Mutex>, +} + /// Global shared state accessible from every request / WS handler via `web::Data`. #[derive(Debug)] pub struct AppState { @@ -23,6 +33,10 @@ pub struct AppState { pub frame_buffer_size: usize, /// Per-session frame buffers (session_id → ring of latest base64 frames). pub frame_buffers: DashMap>, + /// Per-session viewer registries (session_id → list of connected viewers). + pub viewers: DashMap>, + /// Per-session agent message channels (session_id → mpsc sender to agent WS task). + pub agent_channels: DashMap>>, } /// Simple circular buffer that keeps the *N* most recent display frames. @@ -68,6 +82,8 @@ impl AppState { idle_timeout_secs, frame_buffer_size, frame_buffers: DashMap::new(), + viewers: DashMap::new(), + agent_channels: DashMap::new(), }) } @@ -89,8 +105,21 @@ impl AppState { self.sessions.get(id).map(|r| r.value().clone()) } - /// Remove a session (and its frame buffer). + /// Remove a session (and its frame buffer, viewers, agent channel). pub fn remove_session(&self, id: &str) -> bool { + // Notify all viewers that the session is gone. + if let Some(viewers) = self.viewers.get(id) { + for viewer in viewers.iter() { + let msg = serde_json::json!({"msg_type": "error", "message": "session deleted"}) + .to_string(); + let _ = viewer.sender.lock().send(msg).await; + } + } + self.viewers.remove(id); + // Close the agent channel so its writer task exits. + if let Some((_, ch)) = self.agent_channels.remove(id) { + drop(ch); // Drop the Mutex → Sender is dropped → Receiver gets None. + } let removed = self.sessions.remove(id).is_some(); self.frame_buffers.remove(id); if removed { @@ -118,7 +147,6 @@ impl AppState { /// Unregister an agent and mark its session as disconnected. pub fn unregister_agent(&self, agent_id: &str) { if let Some((_, agent)) = self.agents.remove(agent_id) { - // Mark the session as disconnected if no other agents are connected for it. let session_id = agent.session_id.clone(); let still_has_agent = self .agents @@ -142,13 +170,94 @@ impl AppState { } } - /// Return counts for health-check. - pub fn stats(&self) -> (usize, usize) { - let _active: usize = self + /// Get the latest frame from the session's buffer (for new viewer catch-up). + pub fn get_latest_frame(&self, session_id: &str) -> Option { + self.frame_buffers.get(session_id).and_then(|buf| buf.latest()) + } + + // ── Viewer channel management ─────────────────────────────────────────── + + /// Register a new viewer for a session. + pub fn register_viewer(&self, session_id: &str, viewer_id: &str, sender: mpsc::Sender) { + let entry = ViewerEntry { + viewer_id: viewer_id.to_string(), + sender: parking_lot::Mutex::new(sender), + }; + if let Some(mut viewers) = self.viewers.get_mut(session_id) { + viewers.push(entry); + } else { + self.viewers.insert(session_id.to_string(), vec![entry]); + } + info!("viewer registered: {} for session {}", viewer_id, session_id); + } + + /// Unregister a viewer (called on disconnect). + pub fn unregister_viewer(&self, session_id: &str, viewer_id: &str) { + if let Some(mut viewers) = self.viewers.get_mut(session_id) { + viewers.retain(|v| v.viewer_id != viewer_id); + if viewers.is_empty() { + drop(viewers); + self.viewers.remove(session_id); + } + } + info!("viewer unregistered: {} from session {}", viewer_id, session_id); + } + + // ── Agent channel management ──────────────────────────────────────────── + + /// Register the agent's message channel for a session. + pub fn register_agent_channel(&self, session_id: &str, sender: mpsc::Sender) { + self.agent_channels + .insert(session_id.to_string(), parking_lot::Mutex::new(sender)); + info!("agent channel registered for session {}", session_id); + } + + /// Unregister the agent's message channel. + pub fn unregister_agent_channel(&self, session_id: &str) { + if let Some((_, ch)) = self.agent_channels.remove(session_id) { + drop(ch); + } + info!("agent channel unregistered for session {}", session_id); + } + + // ── Broadcast / Forward ──────────────────────────────────────────────── + + /// Broadcast a JSON message to all viewers connected to a session. + pub async fn broadcast_to_viewers(&self, session_id: &str, json_msg: &str) { + if let Some(viewers) = self.viewers.get(session_id) { + for viewer in viewers.iter() { + if let Err(e) = viewer.sender.lock().send(json_msg.to_string()).await { + warn!("[ws] failed to send to viewer {}: {}", viewer.viewer_id, e); + } + } + } + } + + /// Send a JSON message to the agent connected to a session. + /// Returns false if no agent channel exists. + pub async fn send_to_agent(&self, session_id: &str, json_msg: &str) -> bool { + if let Some(ch) = self.agent_channels.get(session_id) { + match ch.lock().send(json_msg.to_string()).await { + Ok(()) => true, + Err(e) => { + warn!("[ws] failed to send to agent in session {}: {}", session_id, e); + false + } + } + } else { + false + } + } + + /// Return counts for health-check: (active_sessions, connected_agents, connected_viewers). + pub fn stats(&self) -> (usize, usize, usize) { + let active: usize = self .sessions .iter() .filter(|r| r.status == SessionStatus::Active) .count(); - (self.sessions.len(), self.agents.len()) + let agent_count = self.agents.len(); + let viewer_count: usize = self.viewers.iter().map(|r| r.value().len()).sum(); + (active, agent_count, viewer_count) } }