wactorz_agents/
ml_agent.rs

1//! ML inference base agent.
2//!
3//! [`MlAgent`] provides a base implementation for machine-learning agents that
4//! run local inference (e.g. ONNX, PyTorch via candle, or HTTP microservices).
5//!
6//! Subclasses override [`MlAgent::infer`] to implement model-specific logic.
7//! Results are published to the MQTT `agents/{id}/detections` topic.
8
9use anyhow::Result;
10use async_trait::async_trait;
11use serde::{Deserialize, Serialize};
12use std::sync::Arc;
13use tokio::sync::mpsc;
14
15use wactorz_core::{Actor, ActorConfig, ActorMetrics, ActorState, EventPublisher, Message};
16
/// A generic inference result (can be subclassed via config).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceResult {
    /// Model-specific label/class.
    pub label: String,
    /// Confidence score in `[0.0, 1.0]`.
    // NOTE(review): the range is documented but never enforced in this
    // file — backends are trusted to return in-range values. Only the
    // lower `confidence_threshold` cut is applied (see `infer_http`).
    pub confidence: f32,
    /// Arbitrary model-specific metadata (free-form JSON).
    pub metadata: serde_json::Value,
}
27
/// Backend selection for the ML agent.
// Serialized with snake_case tags, e.g. `{"http_service": {"url": "..."}}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MlBackend {
    /// Remote HTTP microservice (POST JSON, receive JSON).
    HttpService { url: String },
    /// ONNX runtime (local file path to `.onnx` model).
    // Not yet implemented: `MlAgent::infer` bails with an error.
    Onnx { model_path: String },
    /// candle (Rust-native PyTorch-like) — for future use.
    // Not yet implemented: `MlAgent::infer` bails with an error.
    Candle { model_path: String },
}
39
40impl Default for MlBackend {
41    fn default() -> Self {
42        MlBackend::HttpService {
43            url: "http://localhost:5000/infer".into(),
44        }
45    }
46}
47
/// Configuration for an ML agent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MlConfig {
    /// Which inference backend to use (HTTP service, ONNX, or candle).
    pub backend: MlBackend,
    /// Confidence threshold below which results are discarded.
    pub confidence_threshold: f32,
    /// Maximum batch size for inference.
    // NOTE(review): not consulted anywhere in this file — presumably read
    // by specialised agents; confirm before relying on it.
    pub batch_size: usize,
}
57
58impl Default for MlConfig {
59    fn default() -> Self {
60        Self {
61            backend: MlBackend::default(),
62            confidence_threshold: 0.5,
63            batch_size: 1,
64        }
65    }
66}
67
/// Base ML inference actor.
pub struct MlAgent {
    // Generic actor configuration (id, name, mailbox capacity, heartbeat, ...).
    config: ActorConfig,
    // ML-specific settings: backend, confidence threshold, batch size.
    ml_config: MlConfig,
    // Shared HTTP client used by the `HttpService` backend.
    http: reqwest::Client,
    // Lifecycle state reported through `Actor::state`.
    state: ActorState,
    // Counters shared with the runtime via `Actor::metrics`.
    metrics: Arc<ActorMetrics>,
    // Sender half handed out through `Actor::mailbox`.
    mailbox_tx: mpsc::Sender<Message>,
    // Receiver half; taken exactly once by `run` (None while running).
    mailbox_rx: Option<mpsc::Receiver<Message>>,
    // Optional MQTT publisher for inference results.
    publisher: Option<EventPublisher>,
}
79
80impl MlAgent {
81    pub fn new(config: ActorConfig, ml_config: MlConfig) -> Self {
82        let (tx, rx) = mpsc::channel(config.mailbox_capacity);
83        Self {
84            config,
85            ml_config,
86            http: reqwest::Client::new(),
87            state: ActorState::Initializing,
88            metrics: Arc::new(ActorMetrics::new()),
89            mailbox_tx: tx,
90            mailbox_rx: Some(rx),
91            publisher: None,
92        }
93    }
94
95    /// Attach an EventPublisher for MQTT output.
96    pub fn with_publisher(mut self, p: EventPublisher) -> Self {
97        self.publisher = Some(p);
98        self
99    }
100
101    /// Run inference on a raw input payload.
102    ///
103    /// Override this method in specialised ML agents.
104    pub async fn infer(&self, input: &serde_json::Value) -> Result<Vec<InferenceResult>> {
105        match &self.ml_config.backend {
106            MlBackend::HttpService { url } => self.infer_http(url, input).await,
107            MlBackend::Onnx { model_path } => {
108                anyhow::bail!("ONNX backend not yet implemented (model: {model_path})")
109            }
110            MlBackend::Candle { model_path } => {
111                anyhow::bail!("Candle backend not yet implemented (model: {model_path})")
112            }
113        }
114    }
115
116    async fn infer_http(
117        &self,
118        url: &str,
119        input: &serde_json::Value,
120    ) -> Result<Vec<InferenceResult>> {
121        let resp = self.http.post(url).json(input).send().await?;
122        if !resp.status().is_success() {
123            let s = resp.status();
124            let t = resp.text().await.unwrap_or_default();
125            anyhow::bail!("ML service {s}: {t}");
126        }
127        let mut results: Vec<InferenceResult> = resp.json().await?;
128        results.retain(|r| r.confidence >= self.ml_config.confidence_threshold);
129        Ok(results)
130    }
131}
132
#[async_trait]
impl Actor for MlAgent {
    /// Actor id, cloned out of the config.
    fn id(&self) -> String {
        self.config.id.clone()
    }
    /// Human-readable actor name (borrowed from config).
    fn name(&self) -> &str {
        &self.config.name
    }
    /// Current lifecycle state (Initializing → Running → Stopped).
    fn state(&self) -> ActorState {
        self.state.clone()
    }
    /// Shared metrics handle for the runtime/supervisor.
    fn metrics(&self) -> Arc<ActorMetrics> {
        Arc::clone(&self.metrics)
    }
    /// Cloneable sender for delivering messages to this actor.
    fn mailbox(&self) -> mpsc::Sender<Message> {
        self.mailbox_tx.clone()
    }
    /// Whether the supervisor should refuse to kill this actor.
    fn is_protected(&self) -> bool {
        self.config.protected
    }

    /// Extract a JSON payload from the message, run inference, and publish
    /// surviving results to the agent's MQTT detections topic.
    ///
    /// Only `Task` (payload used as-is) and `Text` (wrapped as a JSON
    /// string) messages are processed; all other message types are ignored.
    async fn handle_message(&mut self, message: Message) -> Result<()> {
        use wactorz_core::message::MessageType;
        let input = match &message.payload {
            MessageType::Task { payload, .. } => payload.clone(),
            MessageType::Text { content } => serde_json::Value::String(content.clone()),
            // Commands etc. are dealt with by the run loop, not here.
            _ => return Ok(()),
        };
        let results = self.infer(&input).await?;
        if let Some(pub_) = &self.publisher {
            // NOTE(review): the return value of `publish` is discarded. If
            // `EventPublisher::publish` is fallible or async, errors (or an
            // un-awaited future) are silently dropped here — confirm its
            // signature and handle/await the result if needed.
            pub_.publish(
                wactorz_mqtt::topics::detections(&self.config.id),
                &serde_json::json!({ "results": results }),
            );
        }
        Ok(())
    }

    /// Main actor loop: process mailbox messages (with priority over
    /// heartbeats) until the mailbox closes or a `Stop` command arrives.
    ///
    /// # Errors
    /// Propagates failures from `on_start` and a double-`run` attempt;
    /// per-message errors are logged and counted but do not stop the loop.
    async fn run(&mut self) -> Result<()> {
        self.on_start().await?;
        self.state = ActorState::Running;
        // Take the receiver exactly once; a second call to `run` fails
        // instead of silently creating a competing consumer.
        let mut rx = self
            .mailbox_rx
            .take()
            .ok_or_else(|| anyhow::anyhow!("MlAgent already running"))?;
        let mut hb = tokio::time::interval(std::time::Duration::from_secs(
            self.config.heartbeat_interval_secs,
        ));
        // Skip (rather than burst) heartbeat ticks missed while a slow
        // inference call held up the loop.
        hb.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
        loop {
            tokio::select! {
                // `biased` polls branches in declaration order, so pending
                // messages always win over a due heartbeat tick.
                biased;
                msg = rx.recv() => {
                    match msg {
                        // All senders dropped: mailbox closed — shut down.
                        None => break,
                        Some(m) => {
                            self.metrics.record_received();
                            // A Stop command exits the loop; it is counted
                            // as received but not as processed.
                            if let wactorz_core::message::MessageType::Command {
                                command: wactorz_core::message::ActorCommand::Stop
                            } = &m.payload {
                                break;
                            }
                            match self.handle_message(m).await {
                                Ok(_) => self.metrics.record_processed(),
                                Err(e) => {
                                    // Log and keep going: one bad message
                                    // must not kill the agent.
                                    tracing::error!("[{}] {e}", self.config.name);
                                    self.metrics.record_failed();
                                }
                            }
                        }
                    }
                }
                _ = hb.tick() => {
                    self.metrics.record_heartbeat();
                    if let Err(e) = self.on_heartbeat().await {
                        tracing::error!("[{}] heartbeat: {e}", self.config.name);
                    }
                }
            }
        }
        self.state = ActorState::Stopped;
        self.on_stop().await
    }
}