wactorz_interfaces/
ws.rs

1//! WebSocket routes for the Wactorz server.
2//!
3//! Two routes are mounted under the same axum `Router`:
4//!
5//! - `/ws`   — Python-compatible aggregated-state bridge.
6//!   Compatible with `monitor.html` (and any client expecting
7//!   `full_snapshot` / `patch` / `delete_agent` JSON messages).
8//!
9//! - `/mqtt` — Transparent WebSocket proxy to the Mosquitto broker's WS
10//!   listener (configurable host/port, default `localhost:9001`).
11//!   Compatible with `mqtt.js` / `frontend/dist/index.html`.
12//!
13//! Together these two routes ensure **any combination** of
14//! `python|rust` backend × `monitor.html|frontend/dist/index.html` frontend
15//! works without any client-side changes.
16//!
17//! ## `/ws` message protocol  (mirrors `monitor_server.py`)
18//!
19//! **Server → browser** on connect:
20//! ```json
21//! { "type": "full_snapshot", "state": { "agents": [...], "nodes": [...], ... } }
22//! ```
23//! **Server → browser** on MQTT event:
24//! ```json
25//! { "type": "patch", "event": { ... }, "state": { ... } }
26//! ```
27//! **Server → browser** after delete command:
28//! ```json
29//! { "type": "delete_agent", "agent_id": "...", "state": { ... } }
30//! ```
31//! **Browser → server** (commands):
32//! ```json
33//! { "type": "command", "command": "pause|stop|resume|delete", "agent_id": "..." }
34//! ```
35
36use std::collections::HashMap;
37use std::sync::Arc;
38use std::time::{SystemTime, UNIX_EPOCH};
39
40use axum::{
41    Router,
42    extract::{
43        State,
44        ws::{Message, WebSocket, WebSocketUpgrade},
45    },
46    http::HeaderMap,
47    response::IntoResponse,
48    routing::get,
49};
50use futures_util::{SinkExt, StreamExt};
51use serde::{Deserialize, Serialize};
52use serde_json::{Value, json};
53use tokio::sync::{Mutex, broadcast};
54
55use wactorz_mqtt::MqttClient;
56
57// ── Internal MQTT envelope (Rust MQTT loop → WS state aggregator) ─────────────
58
/// Raw MQTT message forwarded from the broker event loop.
/// Consumed by [`WsBridge::spawn_monitor_task`]; not sent to browser clients.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WsEnvelope {
    /// Topic of the message; [`MonitorState::parse_topic`] splits it on `/`
    /// to decide how the payload is applied.
    pub topic: String,
    /// Message payload as a JSON value.
    pub payload: Value,
}
66
67// ── Monitor state ─────────────────────────────────────────────────────────────
68
/// Current wall-clock time as fractional seconds since the Unix epoch.
///
/// Yields `0.0` when the system clock reports a time before the epoch.
fn now_secs() -> f64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs_f64(),
        Err(_) => 0.0,
    }
}
75
/// Mirrors the in-memory state maintained by Python's `monitor_server.py`.
#[derive(Debug, Default)]
pub struct MonitorState {
    /// Per-agent JSON objects, keyed by agent id.
    agents: HashMap<String, Value>,
    /// Per-node JSON objects, keyed by node name.
    nodes: HashMap<String, Value>,
    /// Alerts, newest first; capped at 50 entries.
    alerts: Vec<Value>,
    /// Log-feed entries, newest first; capped at 100 entries.
    log_feed: Vec<Value>,
    /// Payload of the most recent `system/health` message.
    system_health: Value,
}
85
86impl MonitorState {
87    /// Serialisable snapshot sent to browser clients.
88    pub fn snapshot(&self) -> Value {
89        let agents: Vec<Value> = self.agents.values().cloned().collect();
90        let nodes: Vec<Value> = self.nodes.values().cloned().collect();
91        let total_cost: f64 = self
92            .agents
93            .values()
94            .filter_map(|a| a.get("cost_usd").and_then(|v| v.as_f64()))
95            .sum();
96        let alert_end = self.alerts.len().min(10);
97        let log_end = self.log_feed.len().min(20);
98        json!({
99            "agents":          agents,
100            "nodes":           nodes,
101            "alerts":          &self.alerts[..alert_end],
102            "log_feed":        &self.log_feed[..log_end],
103            "system_health":   self.system_health,
104            "total_cost_usd":  (total_cost * 1_000_000.0).round() / 1_000_000.0,
105        })
106    }
107
108    fn update_agent(&mut self, agent_id: &str, key: &str, data: Value) {
109        let short = &agent_id[..agent_id.len().min(8)];
110        let entry = self.agents.entry(agent_id.to_string()).or_insert_with(|| {
111            json!({
112                "agent_id":   agent_id,
113                "name":       short,
114                "first_seen": now_secs(),
115            })
116        });
117        if let Some(obj) = entry.as_object_mut() {
118            obj.insert(key.to_string(), data);
119            obj.insert("last_update".to_string(), json!(now_secs()));
120        }
121    }
122
123    fn add_log(&mut self, entry: Value) {
124        self.log_feed.insert(0, entry);
125        if self.log_feed.len() > 100 {
126            self.log_feed.pop();
127        }
128    }
129
130    /// Parse one MQTT message and update internal state.
131    ///
132    /// Returns `Some((event, is_heartbeat))` when something should be
133    /// broadcast, or `None` when the topic is not recognised.
134    /// `is_heartbeat` suppresses the event from the browser's log feed
135    /// (mirrors Python behaviour).
136    pub fn parse_topic(&mut self, topic: &str, payload: Value) -> Option<(Value, bool)> {
137        let parts: Vec<&str> = topic.split('/').collect();
138
139        // ── system/# ────────────────────────────────────────────────────────
140        if parts[0] == "system" && parts.len() >= 2 {
141            match parts[1] {
142                "health" => {
143                    self.system_health = payload.clone();
144                }
145                "alerts" => {
146                    self.alerts.insert(0, payload.clone());
147                    if self.alerts.len() > 50 {
148                        self.alerts.pop();
149                    }
150                }
151                _ => {}
152            }
153            return Some((
154                json!({
155                    "type":    "system",
156                    "subtype": parts[1],
157                    "data":    payload,
158                }),
159                false,
160            ));
161        }
162
163        // ── agents/{id}/{metric} ─────────────────────────────────────────────
164        if parts[0] == "agents" && parts.len() >= 3 {
165            let agent_id = parts[1];
166            let metric = parts[2];
167
168            match metric {
169                "status" => {
170                    self.update_agent(agent_id, "status", payload.clone());
171                    if let Some(obj) = payload.as_object()
172                        && let Some(entry) = self.agents.get_mut(agent_id)
173                        && let Some(e) = entry.as_object_mut()
174                    {
175                        if let Some(n) = obj.get("name") {
176                            e.insert("name".into(), n.clone());
177                        }
178                        if let Some(s) = obj.get("state") {
179                            e.insert("state".into(), s.clone());
180                        }
181                    }
182                    self.add_log(json!({
183                        "type":      "status",
184                        "agent_id":  agent_id,
185                        "status":    payload,
186                        "timestamp": now_secs(),
187                    }));
188                }
189                "heartbeat" => {
190                    self.update_agent(agent_id, "heartbeat", payload.clone());
191                    if let Some(obj) = payload.as_object() {
192                        let short = &agent_id[..agent_id.len().min(8)];
193                        let name = obj.get("name").and_then(|v| v.as_str()).unwrap_or(short);
194                        if let Some(entry) = self.agents.get_mut(agent_id)
195                            && let Some(e) = entry.as_object_mut()
196                        {
197                            e.insert("name".into(), json!(name));
198                            for k in &["cpu", "state"] {
199                                if let Some(v) = obj.get(*k) {
200                                    e.insert(k.to_string(), v.clone());
201                                }
202                            }
203                            if let Some(v) = obj.get("memory_mb") {
204                                e.insert("mem".into(), v.clone());
205                            }
206                            if let Some(v) = obj.get("task") {
207                                e.insert("task".into(), v.clone());
208                            }
209                        }
210                    }
211                    // heartbeat → broadcast state update but suppress from log_feed
212                    return Some((
213                        json!({
214                            "type":     "agent",
215                            "agent_id": agent_id,
216                            "metric":   metric,
217                            "data":     payload,
218                        }),
219                        true,
220                    ));
221                }
222                "metrics" => {
223                    self.update_agent(agent_id, "metrics", payload.clone());
224                    if let Some(obj) = payload.as_object()
225                        && let Some(entry) = self.agents.get_mut(agent_id)
226                        && let Some(e) = entry.as_object_mut()
227                    {
228                        for k in &[
229                            "messages_processed",
230                            "cost_usd",
231                            "input_tokens",
232                            "output_tokens",
233                        ] {
234                            if let Some(v) = obj.get(*k) {
235                                e.insert(k.to_string(), v.clone());
236                            }
237                        }
238                    }
239                }
240                "logs" => {
241                    let mut log = json!({
242                        "type":      "log",
243                        "agent_id":  agent_id,
244                        "timestamp": now_secs(),
245                    });
246                    if let (Some(src), Some(dst)) = (payload.as_object(), log.as_object_mut()) {
247                        for (k, v) in src {
248                            dst.entry(k.clone()).or_insert(v.clone());
249                        }
250                    }
251                    self.add_log(log);
252                }
253                "spawned" => {
254                    let mut log = json!({
255                        "type":      "spawned",
256                        "agent_id":  agent_id,
257                        "timestamp": now_secs(),
258                    });
259                    if let (Some(src), Some(dst)) = (payload.as_object(), log.as_object_mut()) {
260                        for (k, v) in src {
261                            dst.entry(k.clone()).or_insert(v.clone());
262                        }
263                    }
264                    self.add_log(log);
265                }
266                "completed" => {
267                    self.update_agent(agent_id, "last_completed", payload.clone());
268                    self.add_log(json!({
269                        "type":      "completed",
270                        "agent_id":  agent_id,
271                        "timestamp": now_secs(),
272                    }));
273                }
274                "alert" => {
275                    let short = &agent_id[..agent_id.len().min(8)];
276                    let known_name = self
277                        .agents
278                        .get(agent_id)
279                        .and_then(|a| a.get("name"))
280                        .and_then(|v| v.as_str())
281                        .unwrap_or(short)
282                        .to_string();
283                    let enriched = if let Some(obj) = payload.as_object() {
284                        let mut e = obj.clone();
285                        e.insert("agent_id".into(), json!(agent_id));
286                        e.entry("name".to_string())
287                            .or_insert_with(|| json!(&known_name));
288                        Value::Object(e)
289                    } else {
290                        json!({ "agent_id": agent_id })
291                    };
292                    let severity = enriched
293                        .get("severity")
294                        .and_then(|v| v.as_str())
295                        .unwrap_or("warning")
296                        .to_string();
297                    let name = enriched
298                        .get("name")
299                        .and_then(|v| v.as_str())
300                        .unwrap_or(&known_name)
301                        .to_string();
302                    self.alerts.insert(0, enriched);
303                    if self.alerts.len() > 50 {
304                        self.alerts.pop();
305                    }
306                    self.add_log(json!({
307                        "type":      "alert",
308                        "agent_id":  agent_id,
309                        "name":      name,
310                        "message":   format!("{name} unresponsive ({severity})"),
311                        "timestamp": now_secs(),
312                    }));
313                }
314                _ => {}
315            }
316            return Some((
317                json!({
318                    "type":     "agent",
319                    "agent_id": agent_id,
320                    "metric":   metric,
321                    "data":     payload,
322                }),
323                false,
324            ));
325        }
326
327        // ── nodes/{name}/heartbeat ───────────────────────────────────────────
328        if parts[0] == "nodes" && parts.len() >= 3 && parts[2] == "heartbeat" {
329            let node_name = parts[1];
330            if let Some(obj) = payload.as_object() {
331                self.nodes.insert(
332                    node_name.to_string(),
333                    json!({
334                        "node":      node_name,
335                        "agents":    obj.get("agents").cloned().unwrap_or(json!([])),
336                        "last_seen": now_secs(),
337                        "online":    true,
338                        "node_id":   obj.get("node_id").cloned().unwrap_or(json!("")),
339                    }),
340                );
341            }
342            return Some((
343                json!({
344                    "type":      "node",
345                    "node_name": node_name,
346                    "data":      payload,
347                }),
348                false,
349            ));
350        }
351
352        None
353    }
354}
355
356// ── Shared bridge state ───────────────────────────────────────────────────────
357
/// State shared by the `/ws` and `/mqtt` route handlers.
///
/// A clone is attached to the router and extracted in each handler through
/// axum's `State` extractor.
#[derive(Clone)]
pub struct BridgeState {
    /// MQTT → WS broadcast (raw envelopes, consumed by monitor task).
    pub mqtt_tx: broadcast::Sender<WsEnvelope>,
    /// Aggregated monitor state shared across all `/ws` connections.
    pub monitor: Arc<Mutex<MonitorState>>,
    /// Broadcast channel: serialised JSON patches to all `/ws` clients.
    pub monitor_tx: broadcast::Sender<String>,
    /// MQTT client for publishing commands received from the browser.
    pub mqtt_client: Arc<MqttClient>,
    /// Mosquitto WebSocket host (for `/mqtt` proxy).
    pub mqtt_host: String,
    /// Mosquitto WebSocket port (for `/mqtt` proxy, default 9001).
    pub mqtt_ws_port: u16,
}
373
374// ── WsBridge ──────────────────────────────────────────────────────────────────
375
376pub struct WsBridge {
377    state: BridgeState,
378}
379
380impl WsBridge {
381    pub fn new(
382        mqtt_tx: broadcast::Sender<WsEnvelope>,
383        mqtt_client: Arc<MqttClient>,
384        mqtt_host: String,
385        mqtt_ws_port: u16,
386    ) -> Self {
387        let (monitor_tx, _) = broadcast::channel::<String>(256);
388        Self {
389            state: BridgeState {
390                mqtt_tx,
391                monitor: Arc::new(Mutex::new(MonitorState::default())),
392                monitor_tx,
393                mqtt_client,
394                mqtt_host,
395                mqtt_ws_port,
396            },
397        }
398    }
399
400    /// Spawn a background task that:
401    ///
402    /// 1. Consumes raw MQTT envelopes from the broadcast channel.
403    /// 2. Updates [`MonitorState`].
404    /// 3. Broadcasts Python-compatible JSON patches to all `/ws` clients.
405    pub fn spawn_monitor_task(&self) {
406        let mut rx = self.state.mqtt_tx.subscribe();
407        let monitor = Arc::clone(&self.state.monitor);
408        let monitor_tx = self.state.monitor_tx.clone();
409
410        tokio::spawn(async move {
411            while let Ok(envelope) = rx.recv().await {
412                let msg = {
413                    let mut st = monitor.lock().await;
414                    match st.parse_topic(&envelope.topic, envelope.payload) {
415                        None => continue,
416                        Some((event, is_heartbeat)) => {
417                            let snap = st.snapshot();
418                            let log_event = if is_heartbeat { Value::Null } else { event };
419                            serde_json::to_string(&json!({
420                                "type":  "patch",
421                                "event": log_event,
422                                "state": snap,
423                            }))
424                            .unwrap_or_default()
425                        }
426                    }
427                };
428                if !msg.is_empty() {
429                    let _ = monitor_tx.send(msg);
430                }
431            }
432        });
433    }
434
435    /// Build the axum `Router` with `/ws` and `/mqtt` routes.
436    pub fn router(&self) -> Router {
437        Router::new()
438            .route("/ws", get(ws_handler))
439            .route("/mqtt", get(mqtt_proxy_handler))
440            .with_state(self.state.clone())
441    }
442}
443
444// ── /ws handler: Python-compatible aggregated state ───────────────────────────
445
446async fn ws_handler(ws: WebSocketUpgrade, State(state): State<BridgeState>) -> impl IntoResponse {
447    ws.on_upgrade(move |socket| handle_ws_socket(socket, state))
448}
449
450async fn handle_ws_socket(socket: WebSocket, state: BridgeState) {
451    let mut monitor_rx = state.monitor_tx.subscribe();
452    let (mut ws_send, mut ws_recv) = socket.split();
453
454    // Send a full state snapshot immediately on connect (mirrors Python behaviour)
455    let snap_json = {
456        let st = state.monitor.lock().await;
457        serde_json::to_string(&json!({
458            "type":  "full_snapshot",
459            "state": st.snapshot(),
460        }))
461        .unwrap_or_default()
462    };
463    if ws_send.send(Message::Text(snap_json)).await.is_err() {
464        return;
465    }
466
467    // Forward broadcast patches to this client
468    let send_task = tokio::spawn(async move {
469        while let Ok(json) = monitor_rx.recv().await {
470            if ws_send.send(Message::Text(json)).await.is_err() {
471                break;
472            }
473        }
474    });
475
476    // Handle inbound messages (commands from the browser)
477    while let Some(Ok(msg)) = ws_recv.next().await {
478        match msg {
479            Message::Close(_) => break,
480            Message::Text(text) => {
481                handle_browser_command(&text, &state).await;
482            }
483            _ => {}
484        }
485    }
486    send_task.abort();
487}
488
489async fn handle_browser_command(text: &str, state: &BridgeState) {
490    let Ok(cmd) = serde_json::from_str::<Value>(text) else {
491        return;
492    };
493    if cmd.get("type").and_then(|v| v.as_str()) != Some("command") {
494        return;
495    }
496
497    let Some(command) = cmd.get("command").and_then(|v| v.as_str()) else {
498        return;
499    };
500    let Some(agent_id) = cmd.get("agent_id").and_then(|v| v.as_str()) else {
501        return;
502    };
503
504    let valid = ["pause", "stop", "resume", "delete"];
505    if !valid.contains(&command) {
506        tracing::warn!("[ws] Unknown command: {command}");
507        return;
508    }
509
510    tracing::info!(
511        "[ws] {} -> {}",
512        command.to_uppercase(),
513        &agent_id[..agent_id.len().min(8)]
514    );
515
516    // Publish command to MQTT
517    let mqtt_payload = json!({
518        "command":   command,
519        "sender":    "monitor-dashboard",
520        "timestamp": now_secs(),
521    });
522    let topic = format!("agents/{agent_id}/commands");
523    if let Err(e) = state.mqtt_client.publish_json(&topic, &mqtt_payload).await {
524        tracing::error!("[ws] MQTT publish failed: {e}");
525        return;
526    }
527
528    // Optimistic state update + broadcast
529    let msg = {
530        let mut st = state.monitor.lock().await;
531        if command == "delete" {
532            st.agents.remove(agent_id);
533            let snap = st.snapshot();
534            serde_json::to_string(&json!({
535                "type":     "delete_agent",
536                "agent_id": agent_id,
537                "state":    snap,
538            }))
539            .unwrap_or_default()
540        } else {
541            if let Some(entry) = st.agents.get_mut(agent_id)
542                && let Some(e) = entry.as_object_mut()
543            {
544                let new_state = match command {
545                    "stop" => "stopped",
546                    "pause" => "paused",
547                    "resume" => "running",
548                    _ => return,
549                };
550                e.insert("state".into(), json!(new_state));
551            }
552            let snap = st.snapshot();
553            serde_json::to_string(&json!({
554                "type":  "patch",
555                "state": snap,
556            }))
557            .unwrap_or_default()
558        }
559    };
560
561    if !msg.is_empty() {
562        let _ = state.monitor_tx.send(msg);
563    }
564}
565
566// ── /mqtt handler: transparent proxy to Mosquitto WS ─────────────────────────
567//
568// The browser's mqtt.js speaks the MQTT binary protocol over WebSocket.
569// We forward every frame verbatim to/from Mosquitto's WS listener (port 9001
570// by default, or whatever --mqtt-ws-port is set to).
571//
572// Supports the "mqtt" subprotocol header so mqtt.js is satisfied.
573
574async fn mqtt_proxy_handler(
575    ws: WebSocketUpgrade,
576    headers: HeaderMap,
577    State(state): State<BridgeState>,
578) -> impl IntoResponse {
579    // Echo back whichever MQTT sub-protocol the client announced
580    let proto = headers
581        .get("sec-websocket-protocol")
582        .and_then(|v| v.to_str().ok())
583        .map(|s| s.to_string());
584
585    let ws = ws.protocols(["mqtt", "mqttv3.1"]);
586    ws.on_upgrade(move |socket| proxy_to_mosquitto(socket, state, proto))
587}
588
589async fn proxy_to_mosquitto(socket: WebSocket, state: BridgeState, proto: Option<String>) {
590    use tokio_tungstenite::connect_async;
591    use tokio_tungstenite::tungstenite::Message as TMsg;
592    use tokio_tungstenite::tungstenite::http::Request;
593
594    let upstream_url = format!("ws://{}:{}/", state.mqtt_host, state.mqtt_ws_port);
595
596    // Build upstream WS request, forwarding the MQTT sub-protocol header
597    let request = {
598        let mut builder = Request::builder().uri(&upstream_url);
599        let p = proto.as_deref().unwrap_or("mqtt");
600        builder = builder.header("Sec-WebSocket-Protocol", p);
601        match builder.body(()) {
602            Ok(r) => r,
603            Err(e) => {
604                tracing::warn!("[mqtt-proxy] bad request: {e}");
605                return;
606            }
607        }
608    };
609
610    let upstream = match connect_async(request).await {
611        Ok((stream, _)) => stream,
612        Err(e) => {
613            tracing::warn!(
614                "[mqtt-proxy] upstream connect failed ({}): {e}",
615                upstream_url
616            );
617            return;
618        }
619    };
620
621    let (mut up_send, mut up_recv) = upstream.split();
622    let (mut cl_send, mut cl_recv) = socket.split();
623
624    // upstream → client
625    let up_to_cl = tokio::spawn(async move {
626        while let Some(Ok(msg)) = up_recv.next().await {
627            let out = match msg {
628                TMsg::Binary(b) => Message::Binary(b),
629                TMsg::Text(t) => Message::Text(t),
630                TMsg::Close(_) => break,
631                _ => continue,
632            };
633            if cl_send.send(out).await.is_err() {
634                break;
635            }
636        }
637    });
638
639    // client → upstream
640    while let Some(Ok(msg)) = cl_recv.next().await {
641        let fwd = match msg {
642            Message::Binary(b) => TMsg::Binary(b),
643            Message::Text(t) => TMsg::Text(t),
644            Message::Close(_) => break,
645            _ => continue,
646        };
647        if up_send.send(fwd).await.is_err() {
648            break;
649        }
650    }
651
652    up_to_cl.abort();
653}