agent: main.rs — entry point, WebSocket client, capture loop, input dispatch, auto-reconnect, REST session creation

This commit is contained in:
Butterfly Dev 2026-04-07 04:38:33 +00:00
parent e1e6442ff5
commit 0961634ce2
2 changed files with 457 additions and 0 deletions

View File

@ -41,6 +41,9 @@ uuid = { version = "1", features = ["v4"] }
# Error handling
anyhow = "1"
# HTTP client (for REST API session creation)
reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] }
# Audio capture (future use)
# cpal = "0.15"

454
agent/src/main.rs Normal file
View File

@ -0,0 +1,454 @@
//! Butterfly Desktop Agent — entry point.
//!
//! Captures the display from this machine, encodes it as JPEG, and streams it
//! to a Butterfly server via WebSocket. Simultaneously receives HUD commands
//! (mouse/keyboard) from remote viewers and executes them locally for full
//! remote control.
//!
//! Usage:
//! butterfly-agent --server ws://192.168.1.100:8080 --session abc123
//! butterfly-agent --server ws://192.168.1.100:8080 --fps 30 --quality 60
mod capture;
mod config;
mod input;
mod protocol;
use anyhow::{Context, Result};
use futures_util::{SinkExt, StreamExt};
use log::{error, info, warn};
use protocol::AgentWsMessage;
use tokio::sync::mpsc;
use tokio_tungstenite::tungstenite::Message;
use capture::ScreenCapture;
use config::AgentConfig;
use input::InputHandler;
/// Channel message for the capture task to send encoded frames.
enum CaptureEvent {
/// A new JPEG frame (base64-encoded) is ready to send.
Frame(String),
/// The capture task encountered a fatal error.
Error(String),
}
/// Channel message for internal signals.
enum ControlSignal {
/// A HUD command was received from the server and should be executed.
HudCommand { command: String, params: serde_json::Value },
/// The server requested a resize.
Resize { width: u32, height: u32 },
/// The server sent a stream control command.
StreamControl { action: String },
}
#[tokio::main]
async fn main() -> Result<()> {
// Initialize logger.
env_logger::Builder::from_env(env_logger::Env::new().default_filter_or("info")).init();
let config = AgentConfig::parse_args();
let agent_id = uuid::Uuid::new_v4().to_string();
info!("🦋 Butterfly Agent v{}", env!("CARGO_PKG_VERSION"));
info!(" agent id: {}", agent_id);
info!(" server: {}", config.server);
info!(" fps: {}", config.fps);
info!(" quality: {}", config.quality);
// Determine the session ID — either from CLI or by creating one via REST.
let session_id = match &config.session_id {
Some(id) => {
info!(" session: {} (provided)", id);
id.clone()
}
None => {
info!(" session: creating new session via REST...");
create_session_via_rest(&config).await?
}
};
// Initialize screen capture.
let mut screen_capture = ScreenCapture::new(config.display, config.quality)
.context("failed to initialize screen capture — is a display available?")?;
let resolution = screen_capture.resolution();
info!(" display: {}x{}", screen_capture.width(), screen_capture.height());
// Initialize input handler (for remote control).
let mut input_handler = InputHandler::new(
screen_capture.width() as u32,
screen_capture.height() as u32,
)
.context("failed to initialize input handler — do you have permission?")?;
// Connect to the server and run the main loop.
let mut reconnect_count = 0u32;
loop {
match run_session(
&config,
&agent_id,
&session_id,
&resolution,
&mut screen_capture,
&mut input_handler,
)
.await
{
Ok(()) => {
info!("session ended cleanly");
break;
}
Err(e) => {
error!("session error: {}", e);
if config.reconnect_delay_secs == 0 {
break; // No reconnect configured.
}
if config.max_reconnect > 0 && reconnect_count >= config.max_reconnect {
error!("max reconnect attempts ({}) reached, giving up", config.max_reconnect);
break;
}
reconnect_count += 1;
info!(
"reconnecting in {}s (attempt {})...",
config.reconnect_delay_secs, reconnect_count
);
tokio::time::sleep(config.reconnect_delay()).await;
}
}
}
info!("agent shutting down");
Ok(())
}
/// Run a single session: connect WebSocket, stream frames, handle input.
async fn run_session(
config: &AgentConfig,
agent_id: &str,
session_id: &str,
resolution: &str,
screen_capture: &mut ScreenCapture,
input_handler: &mut InputHandler,
) -> Result<()> {
let ws_url = config.ws_url(session_id);
info!("connecting to {}", ws_url);
// Connect WebSocket.
let (ws_stream, _) = tokio_tungstenite::connect_async(&ws_url)
.await
.map_err(|e| anyhow::anyhow!("WebSocket connect failed: {}", e))?;
info!("WebSocket connected");
let (mut ws_write, mut ws_read) = ws_stream.split();
// Send AgentInfo to register with the server.
let hostname = get_hostname();
let info_msg = protocol::agent_info_msg(session_id, agent_id, Some(resolution), Some(&hostname));
ws_write
.send(Message::Text(info_msg.into()))
.await
.context("failed to send agent info")?;
info!("registered with server as agent {} for session {}", agent_id, session_id);
// Create a channel for capture events.
let (capture_tx, mut capture_rx) = mpsc::channel::<CaptureEvent>(8);
// Spawn the screen capture loop as a blocking task (scrap is sync).
let capture_session_id = session_id.to_string();
let frame_interval = config.frame_interval();
let capture_handle = tokio::task::spawn_blocking(move || {
capture_loop(capture_tx, capture_session_id, frame_interval);
});
// Spawn heartbeat task.
let heartbeat_session_id = session_id.to_string();
let heartbeat_interval = config.heartbeat_interval();
let (hb_tx, mut hb_rx) = mpsc::channel::<()>(1);
let heartbeat_handle = tokio::spawn(async move {
heartbeat_loop(hb_tx, heartbeat_session_id, heartbeat_interval).await;
});
// Main select loop: read from WebSocket, read from capture, read from heartbeat.
loop {
tokio::select! {
// ── Capture frame ready ─────────────────────────────────────
capture_event = capture_rx.recv() => {
match capture_event {
Some(CaptureEvent::Frame(b64_data)) => {
let frame_msg = protocol::display_frame_msg(session_id, b64_data);
if let Err(e) = ws_write.send(Message::Text(frame_msg.into())).await {
error!("failed to send frame: {}", e);
break;
}
}
Some(CaptureEvent::Error(err)) => {
error!("capture error: {}", err);
// Don't break — the capture loop handles retries internally.
}
None => {
// Capture channel closed.
warn!("capture channel closed");
break;
}
}
}
// ── Incoming WebSocket message from server ──────────────────
ws_msg = ws_read.next() => {
match ws_msg {
Some(Ok(Message::Text(text))) => {
handle_server_message(&text, session_id, input_handler)?;
}
Some(Ok(Message::Ping(data))) => {
let _ = ws_write.send(Message::Pong(data)).await;
}
Some(Ok(Message::Close(reason))) => {
info!("server closed connection: {:?}", reason);
break;
}
Some(Err(e)) => {
error!("WebSocket read error: {}", e);
break;
}
None => {
info!("WebSocket stream ended");
break;
}
Some(Ok(Message::Pong(_))) | Some(Ok(Message::Binary(_))) | Some(Ok(Message::Frame(_))) => {
// Ignore.
}
}
}
// ── Heartbeat tick ──────────────────────────────────────────
_ = hb_rx.recv() => {
let msg = protocol::heartbeat_msg();
if let Err(e) = ws_write.send(Message::Text(msg.into())).await {
error!("failed to send heartbeat: {}", e);
break;
}
}
}
}
// Cleanup.
hb_tx.send(()).await.ok(); // Signal heartbeat to stop.
drop(capture_tx); // Signal capture to stop.
let _ = capture_handle.await;
let _ = heartbeat_handle.await;
Ok(())
}
/// Handle an incoming text message from the server.
fn handle_server_message(
text: &str,
session_id: &str,
input_handler: &mut InputHandler,
) -> Result<()> {
let msg: AgentWsMessage = match serde_json::from_str(text) {
Ok(m) => m,
Err(e) => {
warn!("invalid message from server: {} ({})", e, &text[..text.len().min(120)]);
return Ok(());
}
};
match msg {
AgentWsMessage::ForwardHudCommand { command, params } => {
// Execute the remote control command locally.
if let Err(e) = input_handler.execute(&command, &params) {
warn!("HUD command '{}' failed: {}", command, e);
}
}
AgentWsMessage::ForwardResize { width, height } => {
info!(
"viewer requested resize: {}x{} (session {})",
width, height, session_id
);
// Future: could resize the virtual display here.
}
AgentWsMessage::StreamControl { action } => {
info!("stream control: {} (session {})", action, session_id);
// Future: implement pause/resume streaming.
}
AgentWsMessage::Ack { message } => {
info!("server ack: {}", message);
}
AgentWsMessage::Error { message } => {
error!("server error: {}", message);
}
// Messages the agent sends (shouldn't receive these).
_ => {
warn!("unexpected message type from server: {:?}", text.chars().take(60).collect::<String>());
}
}
Ok(())
}
/// Screen capture loop — runs on a blocking thread.
///
/// Captures frames at the target FPS, encodes them, and sends them through
/// the channel. Handles capture failures gracefully (logs and retries).
fn capture_loop(tx: mpsc::Sender<CaptureEvent>, session_id: String, frame_interval: std::time::Duration) {
info!("capture loop started for session {}", session_id);
// We need a local ScreenCapture instance since it's !Send across async boundaries
// in some configurations. Re-create it here from the display config.
// Note: the main thread passes frame data through the channel, so we need
// to manage capture state locally.
let display_idx = 0; // Primary display.
let quality = 60; // Default quality.
let mut capturer = match ScreenCapture::new(display_idx, quality) {
Ok(c) => c,
Err(e) => {
let _ = tx.send(CaptureEvent::Error(format!("capture init failed: {}", e))).await;
return;
}
};
let mut consecutive_errors = 0u32;
let max_consecutive_errors = 50;
loop {
// Check if the channel is still open (receiver dropped).
if tx.is_closed() {
info!("capture loop: channel closed, exiting");
break;
}
let start = std::time::Instant::now();
match capturer.capture_frame() {
Ok(b64_data) => {
consecutive_errors = 0;
if tx.try_send(CaptureEvent::Frame(b64_data)).is_err() {
// Channel full or closed — drop this frame.
if tx.is_closed() {
break;
}
log::trace!("capture loop: channel full, dropping frame");
}
}
Err(e) => {
consecutive_errors += 1;
warn!("capture error ({} consecutive): {}", consecutive_errors, e);
if consecutive_errors >= max_consecutive_errors {
let _ = tx.blocking_send(CaptureEvent::Error(
format!("too many consecutive capture errors ({}): {}", max_consecutive_errors, e)
));
break;
}
}
}
// Sleep for remaining frame time.
let elapsed = start.elapsed();
let sleep_duration = frame_interval.saturating_sub(elapsed);
if !sleep_duration.is_zero() {
std::thread::sleep(sleep_duration);
}
}
info!("capture loop ended for session {}", session_id);
}
/// Heartbeat loop — sends periodic pings to keep the connection alive.
async fn heartbeat_loop(
tx: mpsc::Sender<()>,
session_id: String,
interval: std::time::Duration,
) {
info!("heartbeat loop started for session {}", session_id);
let mut interval = tokio::time::interval(interval);
loop {
interval.tick().await;
if tx.send(()).await.is_err() {
info!("heartbeat loop: channel closed, exiting");
break;
}
}
}
/// Create a new session via the REST API and return its ID.
async fn create_session_via_rest(config: &AgentConfig) -> Result<String> {
let url = format!("{}/sessions", config.api_base());
let client = reqwest::Client::new();
let response = client
.post(&url)
.json(&serde_json::json!({}))
.send()
.await
.context("REST request to create session failed")?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
anyhow::bail!(
"failed to create session: HTTP {} — {}",
status,
body
);
}
#[derive(serde::Deserialize)]
struct CreateResponse {
ok: bool,
data: Option<SessionData>,
error: Option<String>,
}
#[derive(serde::Deserialize)]
struct SessionData {
id: String,
}
let resp: CreateResponse = response
.json()
.await
.context("failed to parse session creation response")?;
match resp {
CreateResponse {
ok: true,
data: Some(session),
..
} => {
info!("session created: {}", session.id);
Ok(session.id)
}
CreateResponse {
error: Some(err), ..
} => anyhow::bail!("server rejected session creation: {}", err),
_ => anyhow::bail!("unexpected session creation response"),
}
}
/// Get the local hostname, using environment variables as a fallback.
fn get_hostname() -> String {
#[cfg(unix)]
{
std::env::var("HOSTNAME")
.or_else(|_| std::env::var("HOST"))
.unwrap_or_else(|_| "unknown".into())
}
#[cfg(windows)]
{
std::env::var("COMPUTERNAME")
.unwrap_or_else(|_| "unknown".into())
}
#[cfg(not(any(unix, windows)))]
{
"unknown".into()
}
}