diff --git a/agent/src/main.rs b/agent/src/main.rs index 598082f..ac9429c 100644 --- a/agent/src/main.rs +++ b/agent/src/main.rs @@ -1,23 +1,28 @@ //! Butterfly Desktop Agent — entry point. //! -//! Captures the display from this machine, encodes it as JPEG, and streams it -//! to a Butterfly server via WebSocket. Simultaneously receives HUD commands -//! (mouse/keyboard) from remote viewers and executes them locally for full -//! remote control. +//! Captures the display from this machine, encodes it as H.264 or JPEG, and streams +//! it to a Butterfly server via WebSocket binary frames. Simultaneously receives +//! HUD commands (mouse/keyboard) from remote viewers and executes them locally. +//! +//! Wire protocol: +//! - Binary WebSocket frames = raw video (H.264 NALs or JPEG) with 13-byte header +//! - Text WebSocket frames = JSON control messages (HUD, heartbeat, etc.) //! //! Usage: -//! butterfly-agent --server ws://192.168.1.100:8080 --session abc123 -//! butterfly-agent --server ws://192.168.1.100:8080 --fps 30 --quality 60 +//! butterfly-agent --server ws://192.168.1.100:8080 --encoder h264 +//! butterfly-agent --server ws://192.168.1.100:8080 --encoder jpeg --fps 30 mod capture; mod config; +mod encoder; mod input; mod protocol; use anyhow::{Context, Result}; +use encoder::{EncodedFrame, EncoderType}; use futures_util::{SinkExt, StreamExt}; use log::{error, info, warn}; -use protocol::AgentWsMessage; +use protocol::{ControlMessage, FRAME_HEADER_SIZE}; use tokio::sync::mpsc; use tokio_tungstenite::tungstenite::Message; @@ -25,27 +30,16 @@ use capture::ScreenCapture; use config::AgentConfig; use input::InputHandler; -/// Channel message for the capture task to send encoded frames. +/// Event sent from the capture thread to the main loop. enum CaptureEvent { - /// A new JPEG frame (base64-encoded) is ready to send. - Frame(String), - /// The capture task encountered a fatal error. + /// A binary video frame ready to send (complete with header). + BinaryFrame(Vec), + /// The capture thread hit a fatal error. Error(String), } -/// Channel message for internal signals. -enum ControlSignal { - /// A HUD command was received from the server and should be executed. - HudCommand { command: String, params: serde_json::Value }, - /// The server requested a resize. - Resize { width: u32, height: u32 }, - /// The server sent a stream control command. - StreamControl { action: String }, -} - #[tokio::main] async fn main() -> Result<()> { - // Initialize logger. env_logger::Builder::from_env(env_logger::Env::new().default_filter_or("info")).init(); let config = AgentConfig::parse_args(); @@ -54,10 +48,15 @@ async fn main() -> Result<()> { info!("🦋 Butterfly Agent v{}", env!("CARGO_PKG_VERSION")); info!(" agent id: {}", agent_id); info!(" server: {}", config.server); + info!(" encoder: {}", config.encoder); info!(" fps: {}", config.fps); info!(" quality: {}", config.quality); - // Determine the session ID — either from CLI or by creating one via REST. + // Parse encoder type. + let encoder_type: EncoderType = config.encoder.parse() + .map_err(|e| anyhow::anyhow!("{}", e))?; + + // Determine session ID. let session_id = match &config.session_id { Some(id) => { info!(" session: {} (provided)", id); @@ -69,20 +68,28 @@ async fn main() -> Result<()> { } }; - // Initialize screen capture. - let mut screen_capture = ScreenCapture::new(config.display, config.quality) - .context("failed to initialize screen capture — is a display available?")?; + // Initialize screen capture (raw BGRA output). + let screen_capture = ScreenCapture::new(config.display) + .context("failed to initialize screen capture")?; let resolution = screen_capture.resolution(); info!(" display: {}x{}", screen_capture.width(), screen_capture.height()); + // Initialize video encoder. + let mut video_encoder = encoder::create_encoder( + encoder_type, + screen_capture.width(), + screen_capture.height(), + config.quality, + ).context("failed to create encoder")?; + info!(" encoder: {:?} ready", encoder_type); + // Initialize input handler (for remote control). let mut input_handler = InputHandler::new( screen_capture.width() as u32, screen_capture.height() as u32, - ) - .context("failed to initialize input handler — do you have permission?")?; + ).context("failed to initialize input handler")?; - // Connect to the server and run the main loop. + // Run the main connection loop with auto-reconnect. let mut reconnect_count = 0u32; loop { match run_session( @@ -90,31 +97,25 @@ async fn main() -> Result<()> { &agent_id, &session_id, &resolution, - &mut screen_capture, + screen_capture.width(), + screen_capture.height(), + &encoder_type, + &mut video_encoder, &mut input_handler, - ) - .await - { + ).await { Ok(()) => { info!("session ended cleanly"); break; } Err(e) => { error!("session error: {}", e); - - if config.reconnect_delay_secs == 0 { - break; // No reconnect configured. - } + if config.reconnect_delay_secs == 0 { break; } if config.max_reconnect > 0 && reconnect_count >= config.max_reconnect { - error!("max reconnect attempts ({}) reached, giving up", config.max_reconnect); + error!("max reconnect attempts ({}) reached", config.max_reconnect); break; } - reconnect_count += 1; - info!( - "reconnecting in {}s (attempt {})...", - config.reconnect_delay_secs, reconnect_count - ); + info!("reconnecting in {}s (attempt {})...", config.reconnect_delay_secs, reconnect_count); tokio::time::sleep(config.reconnect_delay()).await; } } @@ -124,19 +125,22 @@ async fn main() -> Result<()> { Ok(()) } -/// Run a single session: connect WebSocket, stream frames, handle input. +/// Run a single session: connect, stream video, handle input. +#[allow(clippy::too_many_arguments)] async fn run_session( config: &AgentConfig, agent_id: &str, session_id: &str, resolution: &str, - screen_capture: &mut ScreenCapture, + capture_width: usize, + capture_height: usize, + encoder_type: &EncoderType, + video_encoder: &mut Box, input_handler: &mut InputHandler, ) -> Result<()> { let ws_url = config.ws_url(session_id); info!("connecting to {}", ws_url); - // Connect WebSocket. let (ws_stream, _) = tokio_tungstenite::connect_async(&ws_url) .await .map_err(|e| anyhow::anyhow!("WebSocket connect failed: {}", e))?; @@ -144,92 +148,112 @@ async fn run_session( info!("WebSocket connected"); let (mut ws_write, mut ws_read) = ws_stream.split(); - // Send AgentInfo to register with the server. + // Send AgentInfo (text/JSON). let hostname = get_hostname(); - - let info_msg = protocol::agent_info_msg(session_id, agent_id, Some(resolution), Some(&hostname)); - ws_write - .send(Message::Text(info_msg.into())) - .await + let encoder_name = match encoder_type { + EncoderType::H264 => "h264", + EncoderType::Jpeg => "jpeg", + }; + let info_msg = protocol::agent_info_msg( + session_id, + agent_id, + Some(resolution), + Some(&hostname), + Some(encoder_name), + ); + ws_write.send(Message::Text(info_msg.into())).await .context("failed to send agent info")?; - info!("registered with server as agent {} for session {}", agent_id, session_id); + info!("registered: agent {} session {}", agent_id, session_id); - // Create a channel for capture events. - let (capture_tx, mut capture_rx) = mpsc::channel::(8); + // Channel for capture events. + let (capture_tx, mut capture_rx) = mpsc::channel::(16); - // Spawn the screen capture loop as a blocking task (scrap is sync). - let capture_session_id = session_id.to_string(); - let frame_interval = config.frame_interval(); + // Spawn capture + encode loop (blocking thread). + let cap_session_id = session_id.to_string(); + let cap_encoder_type = *encoder_type; + let cap_quality = config.quality; + let cap_frame_interval = config.frame_interval(); + let cap_width = capture_width; + let cap_height = capture_height; let capture_handle = tokio::task::spawn_blocking(move || { - capture_loop(capture_tx, capture_session_id, frame_interval); + capture_encode_loop( + capture_tx, + cap_session_id, + cap_encoder_type, + cap_quality, + cap_frame_interval, + cap_width, + cap_height, + ); }); - // Spawn heartbeat task. - let heartbeat_session_id = session_id.to_string(); - let heartbeat_interval = config.heartbeat_interval(); + // Spawn heartbeat. + let hb_session_id = session_id.to_string(); + let hb_interval = config.heartbeat_interval(); let (hb_tx, mut hb_rx) = mpsc::channel::<()>(1); let heartbeat_handle = tokio::spawn(async move { - heartbeat_loop(hb_tx, heartbeat_session_id, heartbeat_interval).await; + let mut interval = tokio::time::interval(hb_interval); + loop { + interval.tick().await; + if hb_tx.send(()).await.is_err() { break; } + } }); - // Main select loop: read from WebSocket, read from capture, read from heartbeat. + // Main select loop. + let mut start_time = std::time::Instant::now(); loop { tokio::select! { - // ── Capture frame ready ───────────────────────────────────── + // ── Encoded frame from capture thread ────────────────────── capture_event = capture_rx.recv() => { match capture_event { - Some(CaptureEvent::Frame(b64_data)) => { - let frame_msg = protocol::display_frame_msg(session_id, b64_data); - if let Err(e) = ws_write.send(Message::Text(frame_msg.into())).await { - error!("failed to send frame: {}", e); + Some(CaptureEvent::BinaryFrame(data)) => { + if let Err(e) = ws_write.send(Message::Binary(data.into())).await { + error!("failed to send binary frame: {}", e); break; } } Some(CaptureEvent::Error(err)) => { error!("capture error: {}", err); - // Don't break — the capture loop handles retries internally. } None => { - // Capture channel closed. warn!("capture channel closed"); break; } } } - // ── Incoming WebSocket message from server ────────────────── + // ── Incoming message from server ────────────────────────── ws_msg = ws_read.next() => { match ws_msg { Some(Ok(Message::Text(text))) => { - handle_server_message(&text, session_id, input_handler)?; + // JSON control message (HUD command, heartbeat ack, etc.). + handle_server_text(&text, input_handler)?; } Some(Ok(Message::Ping(data))) => { let _ = ws_write.send(Message::Pong(data)).await; } Some(Ok(Message::Close(reason))) => { - info!("server closed connection: {:?}", reason); + info!("server closed: {:?}", reason); break; } Some(Err(e)) => { - error!("WebSocket read error: {}", e); + error!("WS read error: {}", e); break; } None => { - info!("WebSocket stream ended"); + info!("WS stream ended"); break; } - Some(Ok(Message::Pong(_))) | Some(Ok(Message::Binary(_))) | Some(Ok(Message::Frame(_))) => { - // Ignore. - } + Some(Ok(_)) => { /* Ignore Pong, Binary, Frame */ } } } - // ── Heartbeat tick ────────────────────────────────────────── + // ── Heartbeat tick ──────────────────────────────────────── _ = hb_rx.recv() => { let msg = protocol::heartbeat_msg(); if let Err(e) = ws_write.send(Message::Text(msg.into())).await { - error!("failed to send heartbeat: {}", e); + error!("heartbeat failed: {}", e); break; } } @@ -237,218 +261,185 @@ async fn run_session( } // Cleanup. - hb_tx.send(()).await.ok(); // Signal heartbeat to stop. - drop(capture_tx); // Signal capture to stop. + let _ = hb_tx.send(()).await; + drop(capture_tx); let _ = capture_handle.await; let _ = heartbeat_handle.await; Ok(()) } -/// Handle an incoming text message from the server. -fn handle_server_message( - text: &str, - session_id: &str, - input_handler: &mut InputHandler, -) -> Result<()> { - let msg: AgentWsMessage = match serde_json::from_str(text) { +/// Handle a JSON text message from the server. +fn handle_server_text(text: &str, input_handler: &mut InputHandler) -> Result<()> { + let msg: ControlMessage = match serde_json::from_str(text) { Ok(m) => m, Err(e) => { - warn!("invalid message from server: {} ({})", e, &text[..text.len().min(120)]); + warn!("invalid JSON from server: {} ({})", e, &text[..text.len().min(100)]); return Ok(()); } }; match msg { - AgentWsMessage::ForwardHudCommand { command, params } => { - // Execute the remote control command locally. + ControlMessage::ForwardHudCommand { command, params } => { if let Err(e) = input_handler.execute(&command, ¶ms) { warn!("HUD command '{}' failed: {}", command, e); } } - AgentWsMessage::ForwardResize { width, height } => { - info!( - "viewer requested resize: {}x{} (session {})", - width, height, session_id - ); - // Future: could resize the virtual display here. + ControlMessage::ForwardResize { width, height } => { + info!("resize request: {}x{}", width, height); } - AgentWsMessage::StreamControl { action } => { - info!("stream control: {} (session {})", action, session_id); - // Future: implement pause/resume streaming. + ControlMessage::StreamControl { action } => { + info!("stream control: {}", action); } - AgentWsMessage::Ack { message } => { - info!("server ack: {}", message); + ControlMessage::Ack { message } => { + info!("ack: {}", message); } - AgentWsMessage::Error { message } => { + ControlMessage::Error { message } => { error!("server error: {}", message); } - // Messages the agent sends (shouldn't receive these). _ => { - warn!("unexpected message type from server: {:?}", text.chars().take(60).collect::()); + warn!("unexpected message: {:?}", text.chars().take(60).collect::()); } } - Ok(()) } -/// Screen capture loop — runs on a blocking thread. +/// Capture + encode loop — runs on a blocking thread. /// -/// Captures frames at the target FPS, encodes them, and sends them through -/// the channel. Handles capture failures gracefully (logs and retries). -fn capture_loop(tx: mpsc::Sender, session_id: String, frame_interval: std::time::Duration) { - info!("capture loop started for session {}", session_id); +/// Captures raw BGRA frames, encodes them (H.264 or JPEG), and sends +/// complete binary frames (with header) through the channel. +fn capture_encode_loop( + tx: mpsc::Sender, + session_id: String, + encoder_type: EncoderType, + quality: u8, + frame_interval: std::time::Duration, + width: usize, + height: usize, +) { + info!("capture+encode loop started for session {} (encoder: {:?})", session_id, encoder_type); - // We need a local ScreenCapture instance since it's !Send across async boundaries - // in some configurations. Re-create it here from the display config. - // Note: the main thread passes frame data through the channel, so we need - // to manage capture state locally. - let display_idx = 0; // Primary display. - let quality = 60; // Default quality. - - let mut capturer = match ScreenCapture::new(display_idx, quality) { + // Create a local capturer and encoder for this thread. + let mut capturer = match ScreenCapture::new(0) { Ok(c) => c, Err(e) => { - let _ = tx.send(CaptureEvent::Error(format!("capture init failed: {}", e))).await; + let _ = tx.blocking_send(CaptureEvent::Error(format!("capture init failed: {}", e))); + return; + } + }; + + let mut encoder = match encoder::create_encoder(encoder_type, width, height, quality) { + Ok(enc) => enc, + Err(e) => { + let _ = tx.blocking_send(CaptureEvent::Error(format!("encoder init failed: {}", e))); return; } }; let mut consecutive_errors = 0u32; - let max_consecutive_errors = 50; + let start_time = std::time::Instant::now(); + let mut frame_count = 0u64; loop { - // Check if the channel is still open (receiver dropped). if tx.is_closed() { - info!("capture loop: channel closed, exiting"); + info!("capture+encode loop: channel closed"); break; } - let start = std::time::Instant::now(); + let loop_start = std::time::Instant::now(); - match capturer.capture_frame() { - Ok(b64_data) => { - consecutive_errors = 0; - if tx.try_send(CaptureEvent::Frame(b64_data)).is_err() { - // Channel full or closed — drop this frame. - if tx.is_closed() { - break; - } - log::trace!("capture loop: channel full, dropping frame"); - } - } + // Capture raw BGRA frame. + let raw = match capturer.capture_raw() { + Ok(r) => r, Err(e) => { consecutive_errors += 1; - warn!("capture error ({} consecutive): {}", consecutive_errors, e); - - if consecutive_errors >= max_consecutive_errors { + if consecutive_errors >= 50 { let _ = tx.blocking_send(CaptureEvent::Error( - format!("too many consecutive capture errors ({}): {}", max_consecutive_errors, e) + format!("too many capture errors: {}", e) )); break; } + warn!("capture error ({}): {}", consecutive_errors, e); + std::thread::sleep(frame_interval); + continue; + } + }; + consecutive_errors = 0; + + // Encode the frame. + match encoder.encode_bgra(&raw.bgra, raw.width, raw.height) { + Ok(encoded) => { + frame_count += 1; + let timestamp_ms = start_time.elapsed().as_millis() as u32; + + // Build binary frame: 13-byte header + payload. + let binary_frame = protocol::build_binary_frame( + encoded.frame_type, + timestamp_ms, + raw.width as u32, + raw.height as u32, + &encoded.payload, + ); + + if tx.try_send(CaptureEvent::BinaryFrame(binary_frame)).is_err() { + if tx.is_closed() { break; } + log::trace!("channel full, dropping frame"); + } + } + Err(e) => { + warn!("encode error: {}", e); } } - // Sleep for remaining frame time. - let elapsed = start.elapsed(); - let sleep_duration = frame_interval.saturating_sub(elapsed); - if !sleep_duration.is_zero() { - std::thread::sleep(sleep_duration); + // Frame rate control. + let elapsed = loop_start.elapsed(); + let sleep = frame_interval.saturating_sub(elapsed); + if !sleep.is_zero() { + std::thread::sleep(sleep); } } - info!("capture loop ended for session {}", session_id); + let total = start_time.elapsed(); + info!( + "capture+encode loop ended: {} frames in {:.1}s ({:.0} fps avg)", + frame_count, + total.as_secs_f64(), + if total.as_secs_f64() > 0.0 { frame_count as f64 / total.as_secs_f64() } else { 0.0 } + ); } -/// Heartbeat loop — sends periodic pings to keep the connection alive. -async fn heartbeat_loop( - tx: mpsc::Sender<()>, - session_id: String, - interval: std::time::Duration, -) { - info!("heartbeat loop started for session {}", session_id); - let mut interval = tokio::time::interval(interval); - - loop { - interval.tick().await; - if tx.send(()).await.is_err() { - info!("heartbeat loop: channel closed, exiting"); - break; - } - } -} - -/// Create a new session via the REST API and return its ID. -async fn create_session_via_rest(config: &AgentConfig) -> Result { - let url = format!("{}/sessions", config.api_base()); - - let client = reqwest::Client::new(); - let response = client - .post(&url) - .json(&serde_json::json!({})) - .send() - .await - .context("REST request to create session failed")?; - - if !response.status().is_success() { - let status = response.status(); - let body = response.text().await.unwrap_or_default(); - anyhow::bail!( - "failed to create session: HTTP {} — {}", - status, - body - ); - } - - #[derive(serde::Deserialize)] - struct CreateResponse { - ok: bool, - data: Option, - error: Option, - } - - #[derive(serde::Deserialize)] - struct SessionData { - id: String, - } - - let resp: CreateResponse = response - .json() - .await - .context("failed to parse session creation response")?; - - match resp { - CreateResponse { - ok: true, - data: Some(session), - .. - } => { - info!("session created: {}", session.id); - Ok(session.id) - } - CreateResponse { - error: Some(err), .. - } => anyhow::bail!("server rejected session creation: {}", err), - _ => anyhow::bail!("unexpected session creation response"), - } -} - -/// Get the local hostname, using environment variables as a fallback. +/// Get the local hostname. fn get_hostname() -> String { #[cfg(unix)] - { - std::env::var("HOSTNAME") - .or_else(|_| std::env::var("HOST")) - .unwrap_or_else(|_| "unknown".into()) - } + { std::env::var("HOSTNAME").or_else(|_| std::env::var("HOST")).unwrap_or_else(|_| "unknown".into()) } #[cfg(windows)] - { - std::env::var("COMPUTERNAME") - .unwrap_or_else(|_| "unknown".into()) - } + { std::env::var("COMPUTERNAME").unwrap_or_else(|_| "unknown".into()) } #[cfg(not(any(unix, windows)))] - { - "unknown".into() + { "unknown".into() } +} + +/// Create a new session via REST API. +async fn create_session_via_rest(config: &AgentConfig) -> Result { + let url = format!("{}/sessions", config.api_base()); + let client = reqwest::Client::new(); + + let response = client.post(&url).json(&serde_json::json!({})).send().await + .context("REST session creation failed")?; + + if !response.status().is_success() { + let body = response.text().await.unwrap_or_default(); + anyhow::bail!("session creation failed: {}", body); + } + + #[derive(serde::Deserialize)] + struct Resp { ok: bool, data: Option, error: Option } + #[derive(serde::Deserialize)] + struct SessionData { id: String } + + let resp: Resp = response.json().await.context("parse response failed")?; + match resp { + Resp { ok: true, data: Some(s), .. } => { info!("session created: {}", s.id); Ok(s.id) } + Resp { error: Some(e), .. } => anyhow::bail!("server error: {}", e), + _ => anyhow::bail!("unexpected response"), } }