agent: main.rs — entry point, WebSocket client, capture loop, input dispatch, auto-reconnect, REST session creation
This commit is contained in:
parent
e1e6442ff5
commit
0961634ce2
@ -41,6 +41,9 @@ uuid = { version = "1", features = ["v4"] }
|
||||
# Error handling
|
||||
anyhow = "1"
|
||||
|
||||
# HTTP client (for REST API session creation)
|
||||
reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] }
|
||||
|
||||
# Audio capture (future use)
|
||||
# cpal = "0.15"
|
||||
|
||||
|
||||
454
agent/src/main.rs
Normal file
454
agent/src/main.rs
Normal file
@ -0,0 +1,454 @@
|
||||
//! Butterfly Desktop Agent — entry point.
|
||||
//!
|
||||
//! Captures the display from this machine, encodes it as JPEG, and streams it
|
||||
//! to a Butterfly server via WebSocket. Simultaneously receives HUD commands
|
||||
//! (mouse/keyboard) from remote viewers and executes them locally for full
|
||||
//! remote control.
|
||||
//!
|
||||
//! Usage:
|
||||
//! butterfly-agent --server ws://192.168.1.100:8080 --session abc123
|
||||
//! butterfly-agent --server ws://192.168.1.100:8080 --fps 30 --quality 60
|
||||
|
||||
mod capture;
|
||||
mod config;
|
||||
mod input;
|
||||
mod protocol;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use log::{error, info, warn};
|
||||
use protocol::AgentWsMessage;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
|
||||
use capture::ScreenCapture;
|
||||
use config::AgentConfig;
|
||||
use input::InputHandler;
|
||||
|
||||
/// Event sent by the capture task over its channel to the session loop.
enum CaptureEvent {
    /// A freshly captured frame, already JPEG-encoded and base64-encoded,
    /// ready to be sent.
    Frame(String),

    /// Description of a fatal error hit by the capture task.
    Error(String),
}
|
||||
|
||||
/// Channel message for internal signals.
|
||||
enum ControlSignal {
|
||||
/// A HUD command was received from the server and should be executed.
|
||||
HudCommand { command: String, params: serde_json::Value },
|
||||
/// The server requested a resize.
|
||||
Resize { width: u32, height: u32 },
|
||||
/// The server sent a stream control command.
|
||||
StreamControl { action: String },
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
// Initialize logger.
|
||||
env_logger::Builder::from_env(env_logger::Env::new().default_filter_or("info")).init();
|
||||
|
||||
let config = AgentConfig::parse_args();
|
||||
let agent_id = uuid::Uuid::new_v4().to_string();
|
||||
|
||||
info!("🦋 Butterfly Agent v{}", env!("CARGO_PKG_VERSION"));
|
||||
info!(" agent id: {}", agent_id);
|
||||
info!(" server: {}", config.server);
|
||||
info!(" fps: {}", config.fps);
|
||||
info!(" quality: {}", config.quality);
|
||||
|
||||
// Determine the session ID — either from CLI or by creating one via REST.
|
||||
let session_id = match &config.session_id {
|
||||
Some(id) => {
|
||||
info!(" session: {} (provided)", id);
|
||||
id.clone()
|
||||
}
|
||||
None => {
|
||||
info!(" session: creating new session via REST...");
|
||||
create_session_via_rest(&config).await?
|
||||
}
|
||||
};
|
||||
|
||||
// Initialize screen capture.
|
||||
let mut screen_capture = ScreenCapture::new(config.display, config.quality)
|
||||
.context("failed to initialize screen capture — is a display available?")?;
|
||||
let resolution = screen_capture.resolution();
|
||||
info!(" display: {}x{}", screen_capture.width(), screen_capture.height());
|
||||
|
||||
// Initialize input handler (for remote control).
|
||||
let mut input_handler = InputHandler::new(
|
||||
screen_capture.width() as u32,
|
||||
screen_capture.height() as u32,
|
||||
)
|
||||
.context("failed to initialize input handler — do you have permission?")?;
|
||||
|
||||
// Connect to the server and run the main loop.
|
||||
let mut reconnect_count = 0u32;
|
||||
loop {
|
||||
match run_session(
|
||||
&config,
|
||||
&agent_id,
|
||||
&session_id,
|
||||
&resolution,
|
||||
&mut screen_capture,
|
||||
&mut input_handler,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(()) => {
|
||||
info!("session ended cleanly");
|
||||
break;
|
||||
}
|
||||
Err(e) => {
|
||||
error!("session error: {}", e);
|
||||
|
||||
if config.reconnect_delay_secs == 0 {
|
||||
break; // No reconnect configured.
|
||||
}
|
||||
if config.max_reconnect > 0 && reconnect_count >= config.max_reconnect {
|
||||
error!("max reconnect attempts ({}) reached, giving up", config.max_reconnect);
|
||||
break;
|
||||
}
|
||||
|
||||
reconnect_count += 1;
|
||||
info!(
|
||||
"reconnecting in {}s (attempt {})...",
|
||||
config.reconnect_delay_secs, reconnect_count
|
||||
);
|
||||
tokio::time::sleep(config.reconnect_delay()).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
info!("agent shutting down");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Run a single session: connect WebSocket, stream frames, handle input.
|
||||
async fn run_session(
|
||||
config: &AgentConfig,
|
||||
agent_id: &str,
|
||||
session_id: &str,
|
||||
resolution: &str,
|
||||
screen_capture: &mut ScreenCapture,
|
||||
input_handler: &mut InputHandler,
|
||||
) -> Result<()> {
|
||||
let ws_url = config.ws_url(session_id);
|
||||
info!("connecting to {}", ws_url);
|
||||
|
||||
// Connect WebSocket.
|
||||
let (ws_stream, _) = tokio_tungstenite::connect_async(&ws_url)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("WebSocket connect failed: {}", e))?;
|
||||
|
||||
info!("WebSocket connected");
|
||||
let (mut ws_write, mut ws_read) = ws_stream.split();
|
||||
|
||||
// Send AgentInfo to register with the server.
|
||||
let hostname = get_hostname();
|
||||
|
||||
let info_msg = protocol::agent_info_msg(session_id, agent_id, Some(resolution), Some(&hostname));
|
||||
ws_write
|
||||
.send(Message::Text(info_msg.into()))
|
||||
.await
|
||||
.context("failed to send agent info")?;
|
||||
|
||||
info!("registered with server as agent {} for session {}", agent_id, session_id);
|
||||
|
||||
// Create a channel for capture events.
|
||||
let (capture_tx, mut capture_rx) = mpsc::channel::<CaptureEvent>(8);
|
||||
|
||||
// Spawn the screen capture loop as a blocking task (scrap is sync).
|
||||
let capture_session_id = session_id.to_string();
|
||||
let frame_interval = config.frame_interval();
|
||||
let capture_handle = tokio::task::spawn_blocking(move || {
|
||||
capture_loop(capture_tx, capture_session_id, frame_interval);
|
||||
});
|
||||
|
||||
// Spawn heartbeat task.
|
||||
let heartbeat_session_id = session_id.to_string();
|
||||
let heartbeat_interval = config.heartbeat_interval();
|
||||
let (hb_tx, mut hb_rx) = mpsc::channel::<()>(1);
|
||||
let heartbeat_handle = tokio::spawn(async move {
|
||||
heartbeat_loop(hb_tx, heartbeat_session_id, heartbeat_interval).await;
|
||||
});
|
||||
|
||||
// Main select loop: read from WebSocket, read from capture, read from heartbeat.
|
||||
loop {
|
||||
tokio::select! {
|
||||
// ── Capture frame ready ─────────────────────────────────────
|
||||
capture_event = capture_rx.recv() => {
|
||||
match capture_event {
|
||||
Some(CaptureEvent::Frame(b64_data)) => {
|
||||
let frame_msg = protocol::display_frame_msg(session_id, b64_data);
|
||||
if let Err(e) = ws_write.send(Message::Text(frame_msg.into())).await {
|
||||
error!("failed to send frame: {}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
Some(CaptureEvent::Error(err)) => {
|
||||
error!("capture error: {}", err);
|
||||
// Don't break — the capture loop handles retries internally.
|
||||
}
|
||||
None => {
|
||||
// Capture channel closed.
|
||||
warn!("capture channel closed");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Incoming WebSocket message from server ──────────────────
|
||||
ws_msg = ws_read.next() => {
|
||||
match ws_msg {
|
||||
Some(Ok(Message::Text(text))) => {
|
||||
handle_server_message(&text, session_id, input_handler)?;
|
||||
}
|
||||
Some(Ok(Message::Ping(data))) => {
|
||||
let _ = ws_write.send(Message::Pong(data)).await;
|
||||
}
|
||||
Some(Ok(Message::Close(reason))) => {
|
||||
info!("server closed connection: {:?}", reason);
|
||||
break;
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
error!("WebSocket read error: {}", e);
|
||||
break;
|
||||
}
|
||||
None => {
|
||||
info!("WebSocket stream ended");
|
||||
break;
|
||||
}
|
||||
Some(Ok(Message::Pong(_))) | Some(Ok(Message::Binary(_))) | Some(Ok(Message::Frame(_))) => {
|
||||
// Ignore.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Heartbeat tick ──────────────────────────────────────────
|
||||
_ = hb_rx.recv() => {
|
||||
let msg = protocol::heartbeat_msg();
|
||||
if let Err(e) = ws_write.send(Message::Text(msg.into())).await {
|
||||
error!("failed to send heartbeat: {}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Cleanup.
|
||||
hb_tx.send(()).await.ok(); // Signal heartbeat to stop.
|
||||
drop(capture_tx); // Signal capture to stop.
|
||||
let _ = capture_handle.await;
|
||||
let _ = heartbeat_handle.await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle an incoming text message from the server.
|
||||
fn handle_server_message(
|
||||
text: &str,
|
||||
session_id: &str,
|
||||
input_handler: &mut InputHandler,
|
||||
) -> Result<()> {
|
||||
let msg: AgentWsMessage = match serde_json::from_str(text) {
|
||||
Ok(m) => m,
|
||||
Err(e) => {
|
||||
warn!("invalid message from server: {} ({})", e, &text[..text.len().min(120)]);
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
match msg {
|
||||
AgentWsMessage::ForwardHudCommand { command, params } => {
|
||||
// Execute the remote control command locally.
|
||||
if let Err(e) = input_handler.execute(&command, ¶ms) {
|
||||
warn!("HUD command '{}' failed: {}", command, e);
|
||||
}
|
||||
}
|
||||
AgentWsMessage::ForwardResize { width, height } => {
|
||||
info!(
|
||||
"viewer requested resize: {}x{} (session {})",
|
||||
width, height, session_id
|
||||
);
|
||||
// Future: could resize the virtual display here.
|
||||
}
|
||||
AgentWsMessage::StreamControl { action } => {
|
||||
info!("stream control: {} (session {})", action, session_id);
|
||||
// Future: implement pause/resume streaming.
|
||||
}
|
||||
AgentWsMessage::Ack { message } => {
|
||||
info!("server ack: {}", message);
|
||||
}
|
||||
AgentWsMessage::Error { message } => {
|
||||
error!("server error: {}", message);
|
||||
}
|
||||
// Messages the agent sends (shouldn't receive these).
|
||||
_ => {
|
||||
warn!("unexpected message type from server: {:?}", text.chars().take(60).collect::<String>());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Screen capture loop — runs on a blocking thread.
|
||||
///
|
||||
/// Captures frames at the target FPS, encodes them, and sends them through
|
||||
/// the channel. Handles capture failures gracefully (logs and retries).
|
||||
fn capture_loop(tx: mpsc::Sender<CaptureEvent>, session_id: String, frame_interval: std::time::Duration) {
|
||||
info!("capture loop started for session {}", session_id);
|
||||
|
||||
// We need a local ScreenCapture instance since it's !Send across async boundaries
|
||||
// in some configurations. Re-create it here from the display config.
|
||||
// Note: the main thread passes frame data through the channel, so we need
|
||||
// to manage capture state locally.
|
||||
let display_idx = 0; // Primary display.
|
||||
let quality = 60; // Default quality.
|
||||
|
||||
let mut capturer = match ScreenCapture::new(display_idx, quality) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
let _ = tx.send(CaptureEvent::Error(format!("capture init failed: {}", e))).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let mut consecutive_errors = 0u32;
|
||||
let max_consecutive_errors = 50;
|
||||
|
||||
loop {
|
||||
// Check if the channel is still open (receiver dropped).
|
||||
if tx.is_closed() {
|
||||
info!("capture loop: channel closed, exiting");
|
||||
break;
|
||||
}
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
match capturer.capture_frame() {
|
||||
Ok(b64_data) => {
|
||||
consecutive_errors = 0;
|
||||
if tx.try_send(CaptureEvent::Frame(b64_data)).is_err() {
|
||||
// Channel full or closed — drop this frame.
|
||||
if tx.is_closed() {
|
||||
break;
|
||||
}
|
||||
log::trace!("capture loop: channel full, dropping frame");
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
consecutive_errors += 1;
|
||||
warn!("capture error ({} consecutive): {}", consecutive_errors, e);
|
||||
|
||||
if consecutive_errors >= max_consecutive_errors {
|
||||
let _ = tx.blocking_send(CaptureEvent::Error(
|
||||
format!("too many consecutive capture errors ({}): {}", max_consecutive_errors, e)
|
||||
));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sleep for remaining frame time.
|
||||
let elapsed = start.elapsed();
|
||||
let sleep_duration = frame_interval.saturating_sub(elapsed);
|
||||
if !sleep_duration.is_zero() {
|
||||
std::thread::sleep(sleep_duration);
|
||||
}
|
||||
}
|
||||
|
||||
info!("capture loop ended for session {}", session_id);
|
||||
}
|
||||
|
||||
/// Heartbeat loop — sends periodic pings to keep the connection alive.
|
||||
async fn heartbeat_loop(
|
||||
tx: mpsc::Sender<()>,
|
||||
session_id: String,
|
||||
interval: std::time::Duration,
|
||||
) {
|
||||
info!("heartbeat loop started for session {}", session_id);
|
||||
let mut interval = tokio::time::interval(interval);
|
||||
|
||||
loop {
|
||||
interval.tick().await;
|
||||
if tx.send(()).await.is_err() {
|
||||
info!("heartbeat loop: channel closed, exiting");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new session via the REST API and return its ID.
|
||||
async fn create_session_via_rest(config: &AgentConfig) -> Result<String> {
|
||||
let url = format!("{}/sessions", config.api_base());
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let response = client
|
||||
.post(&url)
|
||||
.json(&serde_json::json!({}))
|
||||
.send()
|
||||
.await
|
||||
.context("REST request to create session failed")?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let body = response.text().await.unwrap_or_default();
|
||||
anyhow::bail!(
|
||||
"failed to create session: HTTP {} — {}",
|
||||
status,
|
||||
body
|
||||
);
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct CreateResponse {
|
||||
ok: bool,
|
||||
data: Option<SessionData>,
|
||||
error: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct SessionData {
|
||||
id: String,
|
||||
}
|
||||
|
||||
let resp: CreateResponse = response
|
||||
.json()
|
||||
.await
|
||||
.context("failed to parse session creation response")?;
|
||||
|
||||
match resp {
|
||||
CreateResponse {
|
||||
ok: true,
|
||||
data: Some(session),
|
||||
..
|
||||
} => {
|
||||
info!("session created: {}", session.id);
|
||||
Ok(session.id)
|
||||
}
|
||||
CreateResponse {
|
||||
error: Some(err), ..
|
||||
} => anyhow::bail!("server rejected session creation: {}", err),
|
||||
_ => anyhow::bail!("unexpected session creation response"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the local hostname on a best-effort basis.
///
/// Sources, in order:
/// - Unix: the `HOSTNAME` and `HOST` environment variables, then the contents
///   of `/etc/hostname`.
/// - Windows: the `COMPUTERNAME` environment variable.
///
/// Values are trimmed, and unset or blank values are skipped. Returns
/// `"unknown"` when no source yields a name.
fn get_hostname() -> String {
    /// Trim a candidate and discard it if nothing is left.
    #[cfg(any(unix, windows))]
    fn non_blank(raw: Option<String>) -> Option<String> {
        raw.map(|value| value.trim().to_string())
            .filter(|value| !value.is_empty())
    }

    // `HOSTNAME` is often a shell variable that is not exported to child
    // processes, so fall back to the hostname file when the environment
    // does not provide a name.
    #[cfg(unix)]
    let found = non_blank(std::env::var("HOSTNAME").ok())
        .or_else(|| non_blank(std::env::var("HOST").ok()))
        .or_else(|| non_blank(std::fs::read_to_string("/etc/hostname").ok()));

    #[cfg(windows)]
    let found = non_blank(std::env::var("COMPUTERNAME").ok());

    #[cfg(not(any(unix, windows)))]
    let found: Option<String> = None;

    found.unwrap_or_else(|| "unknown".into())
}
|
||||
Loading…
Reference in New Issue
Block a user