state: add viewer/agent channel registries, broadcast/forward methods, fix stats()

This commit is contained in:
Butterfly Dev 2026-04-07 03:57:00 +00:00
parent fc15cadf6a
commit 29eda76675

View File

@ -2,10 +2,20 @@ use std::sync::Arc;
use std::time::Instant;
use dashmap::DashMap;
use log::info;
use log::{info, warn};
use tokio::sync::mpsc;
use crate::models::{AgentConnection, Session, SessionStatus};
/// A registered viewer connection with its message channel.
#[derive(Debug)]
pub struct ViewerEntry {
pub viewer_id: String,
/// Channel to send outgoing messages to this viewer.
/// Wrapped in parking_lot::Mutex so it satisfies DashMap's `Send + Sync` bound.
pub sender: parking_lot::Mutex<mpsc::Sender<String>>,
}
/// Global shared state accessible from every request / WS handler via `web::Data`.
#[derive(Debug)]
pub struct AppState {
@ -23,6 +33,10 @@ pub struct AppState {
pub frame_buffer_size: usize,
/// Per-session frame buffers (session_id → ring of latest base64 frames).
pub frame_buffers: DashMap<String, Arc<FrameBuffer>>,
/// Per-session viewer registries (session_id → list of connected viewers).
pub viewers: DashMap<String, Vec<ViewerEntry>>,
/// Per-session agent message channels (session_id → mpsc sender to agent WS task).
pub agent_channels: DashMap<String, parking_lot::Mutex<mpsc::Sender<String>>>,
}
/// Simple circular buffer that keeps the *N* most recent display frames.
@ -68,6 +82,8 @@ impl AppState {
idle_timeout_secs,
frame_buffer_size,
frame_buffers: DashMap::new(),
viewers: DashMap::new(),
agent_channels: DashMap::new(),
})
}
@ -89,8 +105,21 @@ impl AppState {
self.sessions.get(id).map(|r| r.value().clone())
}
/// Remove a session (and its frame buffer).
/// Remove a session (and its frame buffer, viewers, agent channel).
pub fn remove_session(&self, id: &str) -> bool {
// Notify all viewers that the session is gone.
if let Some(viewers) = self.viewers.get(id) {
for viewer in viewers.iter() {
let msg = serde_json::json!({"msg_type": "error", "message": "session deleted"})
.to_string();
let _ = viewer.sender.lock().send(msg).await;
}
}
self.viewers.remove(id);
// Close the agent channel so its writer task exits.
if let Some((_, ch)) = self.agent_channels.remove(id) {
drop(ch); // Drop the Mutex<Sender> → Sender is dropped → Receiver gets None.
}
let removed = self.sessions.remove(id).is_some();
self.frame_buffers.remove(id);
if removed {
@ -118,7 +147,6 @@ impl AppState {
/// Unregister an agent and mark its session as disconnected.
pub fn unregister_agent(&self, agent_id: &str) {
if let Some((_, agent)) = self.agents.remove(agent_id) {
// Mark the session as disconnected if no other agents are connected for it.
let session_id = agent.session_id.clone();
let still_has_agent = self
.agents
@ -142,13 +170,94 @@ impl AppState {
}
}
/// Return counts for health-check.
pub fn stats(&self) -> (usize, usize) {
let _active: usize = self
/// Get the latest frame from the session's buffer (for new viewer catch-up).
pub fn get_latest_frame(&self, session_id: &str) -> Option<String> {
self.frame_buffers.get(session_id).and_then(|buf| buf.latest())
}
// ── Viewer channel management ───────────────────────────────────────────
/// Register a new viewer for a session.
pub fn register_viewer(&self, session_id: &str, viewer_id: &str, sender: mpsc::Sender<String>) {
let entry = ViewerEntry {
viewer_id: viewer_id.to_string(),
sender: parking_lot::Mutex::new(sender),
};
if let Some(mut viewers) = self.viewers.get_mut(session_id) {
viewers.push(entry);
} else {
self.viewers.insert(session_id.to_string(), vec![entry]);
}
info!("viewer registered: {} for session {}", viewer_id, session_id);
}
/// Unregister a viewer (called on disconnect).
pub fn unregister_viewer(&self, session_id: &str, viewer_id: &str) {
if let Some(mut viewers) = self.viewers.get_mut(session_id) {
viewers.retain(|v| v.viewer_id != viewer_id);
if viewers.is_empty() {
drop(viewers);
self.viewers.remove(session_id);
}
}
info!("viewer unregistered: {} from session {}", viewer_id, session_id);
}
// ── Agent channel management ────────────────────────────────────────────
/// Register the agent's message channel for a session.
pub fn register_agent_channel(&self, session_id: &str, sender: mpsc::Sender<String>) {
self.agent_channels
.insert(session_id.to_string(), parking_lot::Mutex::new(sender));
info!("agent channel registered for session {}", session_id);
}
/// Unregister the agent's message channel.
pub fn unregister_agent_channel(&self, session_id: &str) {
if let Some((_, ch)) = self.agent_channels.remove(session_id) {
drop(ch);
}
info!("agent channel unregistered for session {}", session_id);
}
// ── Broadcast / Forward ────────────────────────────────────────────────
/// Broadcast a JSON message to all viewers connected to a session.
pub async fn broadcast_to_viewers(&self, session_id: &str, json_msg: &str) {
if let Some(viewers) = self.viewers.get(session_id) {
for viewer in viewers.iter() {
if let Err(e) = viewer.sender.lock().send(json_msg.to_string()).await {
warn!("[ws] failed to send to viewer {}: {}", viewer.viewer_id, e);
}
}
}
}
/// Send a JSON message to the agent connected to a session.
/// Returns false if no agent channel exists.
pub async fn send_to_agent(&self, session_id: &str, json_msg: &str) -> bool {
if let Some(ch) = self.agent_channels.get(session_id) {
match ch.lock().send(json_msg.to_string()).await {
Ok(()) => true,
Err(e) => {
warn!("[ws] failed to send to agent in session {}: {}", session_id, e);
false
}
}
} else {
false
}
}
/// Return counts for health-check: (active_sessions, connected_agents, connected_viewers).
pub fn stats(&self) -> (usize, usize, usize) {
let active: usize = self
.sessions
.iter()
.filter(|r| r.status == SessionStatus::Active)
.count();
(self.sessions.len(), self.agents.len())
let agent_count = self.agents.len();
let viewer_count: usize = self.viewers.iter().map(|r| r.value().len()).sum();
(active, agent_count, viewer_count)
}
}