wactorz_agents/
ml_agent.rs1use anyhow::Result;
10use async_trait::async_trait;
11use serde::{Deserialize, Serialize};
12use std::sync::Arc;
13use tokio::sync::mpsc;
14
15use wactorz_core::{Actor, ActorConfig, ActorMetrics, ActorState, EventPublisher, Message};
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct InferenceResult {
20 pub label: String,
22 pub confidence: f32,
24 pub metadata: serde_json::Value,
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
30#[serde(rename_all = "snake_case")]
31pub enum MlBackend {
32 HttpService { url: String },
34 Onnx { model_path: String },
36 Candle { model_path: String },
38}
39
40impl Default for MlBackend {
41 fn default() -> Self {
42 MlBackend::HttpService {
43 url: "http://localhost:5000/infer".into(),
44 }
45 }
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct MlConfig {
51 pub backend: MlBackend,
52 pub confidence_threshold: f32,
54 pub batch_size: usize,
56}
57
58impl Default for MlConfig {
59 fn default() -> Self {
60 Self {
61 backend: MlBackend::default(),
62 confidence_threshold: 0.5,
63 batch_size: 1,
64 }
65 }
66}
67
68pub struct MlAgent {
70 config: ActorConfig,
71 ml_config: MlConfig,
72 http: reqwest::Client,
73 state: ActorState,
74 metrics: Arc<ActorMetrics>,
75 mailbox_tx: mpsc::Sender<Message>,
76 mailbox_rx: Option<mpsc::Receiver<Message>>,
77 publisher: Option<EventPublisher>,
78}
79
80impl MlAgent {
81 pub fn new(config: ActorConfig, ml_config: MlConfig) -> Self {
82 let (tx, rx) = mpsc::channel(config.mailbox_capacity);
83 Self {
84 config,
85 ml_config,
86 http: reqwest::Client::new(),
87 state: ActorState::Initializing,
88 metrics: Arc::new(ActorMetrics::new()),
89 mailbox_tx: tx,
90 mailbox_rx: Some(rx),
91 publisher: None,
92 }
93 }
94
95 pub fn with_publisher(mut self, p: EventPublisher) -> Self {
97 self.publisher = Some(p);
98 self
99 }
100
101 pub async fn infer(&self, input: &serde_json::Value) -> Result<Vec<InferenceResult>> {
105 match &self.ml_config.backend {
106 MlBackend::HttpService { url } => self.infer_http(url, input).await,
107 MlBackend::Onnx { model_path } => {
108 anyhow::bail!("ONNX backend not yet implemented (model: {model_path})")
109 }
110 MlBackend::Candle { model_path } => {
111 anyhow::bail!("Candle backend not yet implemented (model: {model_path})")
112 }
113 }
114 }
115
116 async fn infer_http(
117 &self,
118 url: &str,
119 input: &serde_json::Value,
120 ) -> Result<Vec<InferenceResult>> {
121 let resp = self.http.post(url).json(input).send().await?;
122 if !resp.status().is_success() {
123 let s = resp.status();
124 let t = resp.text().await.unwrap_or_default();
125 anyhow::bail!("ML service {s}: {t}");
126 }
127 let mut results: Vec<InferenceResult> = resp.json().await?;
128 results.retain(|r| r.confidence >= self.ml_config.confidence_threshold);
129 Ok(results)
130 }
131}
132
133#[async_trait]
134impl Actor for MlAgent {
135 fn id(&self) -> String {
136 self.config.id.clone()
137 }
138 fn name(&self) -> &str {
139 &self.config.name
140 }
141 fn state(&self) -> ActorState {
142 self.state.clone()
143 }
144 fn metrics(&self) -> Arc<ActorMetrics> {
145 Arc::clone(&self.metrics)
146 }
147 fn mailbox(&self) -> mpsc::Sender<Message> {
148 self.mailbox_tx.clone()
149 }
150 fn is_protected(&self) -> bool {
151 self.config.protected
152 }
153
154 async fn handle_message(&mut self, message: Message) -> Result<()> {
155 use wactorz_core::message::MessageType;
156 let input = match &message.payload {
157 MessageType::Task { payload, .. } => payload.clone(),
158 MessageType::Text { content } => serde_json::Value::String(content.clone()),
159 _ => return Ok(()),
160 };
161 let results = self.infer(&input).await?;
162 if let Some(pub_) = &self.publisher {
163 pub_.publish(
164 wactorz_mqtt::topics::detections(&self.config.id),
165 &serde_json::json!({ "results": results }),
166 );
167 }
168 Ok(())
169 }
170
171 async fn run(&mut self) -> Result<()> {
172 self.on_start().await?;
173 self.state = ActorState::Running;
174 let mut rx = self
175 .mailbox_rx
176 .take()
177 .ok_or_else(|| anyhow::anyhow!("MlAgent already running"))?;
178 let mut hb = tokio::time::interval(std::time::Duration::from_secs(
179 self.config.heartbeat_interval_secs,
180 ));
181 hb.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
182 loop {
183 tokio::select! {
184 biased;
185 msg = rx.recv() => {
186 match msg {
187 None => break,
188 Some(m) => {
189 self.metrics.record_received();
190 if let wactorz_core::message::MessageType::Command {
191 command: wactorz_core::message::ActorCommand::Stop
192 } = &m.payload {
193 break;
194 }
195 match self.handle_message(m).await {
196 Ok(_) => self.metrics.record_processed(),
197 Err(e) => {
198 tracing::error!("[{}] {e}", self.config.name);
199 self.metrics.record_failed();
200 }
201 }
202 }
203 }
204 }
205 _ = hb.tick() => {
206 self.metrics.record_heartbeat();
207 if let Err(e) = self.on_heartbeat().await {
208 tracing::error!("[{}] heartbeat: {e}", self.config.name);
209 }
210 }
211 }
212 }
213 self.state = ActorState::Stopped;
214 self.on_stop().await
215 }
216}