// wactorz_agents/
// home_assistant_agent.rs

//! Home Assistant integration agent.
//!
//! [`HomeAssistantAgent`] connects to a local Home Assistant instance via
//! its REST API.  It can query entity states and call services.
//! (Subscribing to state-change events over the WebSocket event bus is not
//! implemented in this module yet.)
//!
//! Configuration is read from environment variables:
//! - `HA_URL`   — Home Assistant base URL (e.g. `http://homeassistant.local:8123`)
//! - `HA_TOKEN` — Long-lived access token

11use anyhow::Result;
12use async_trait::async_trait;
13use std::sync::Arc;
14use tokio::sync::mpsc;
15
16use crate::llm_agent::{LlmAgent, LlmConfig};
17use wactorz_core::{Actor, ActorConfig, ActorMetrics, ActorState, EventPublisher, Message};
18
19/// Home Assistant agent.
20pub struct HomeAssistantAgent {
21    config: ActorConfig,
22    ha_url: String,
23    ha_token: String,
24    http: reqwest::Client,
25    llm: Option<LlmAgent>,
26    state: ActorState,
27    metrics: Arc<ActorMetrics>,
28    mailbox_tx: mpsc::Sender<Message>,
29    mailbox_rx: Option<mpsc::Receiver<Message>>,
30    publisher: Option<EventPublisher>,
31}
32
33impl HomeAssistantAgent {
34    pub fn new(config: ActorConfig) -> Self {
35        let ha_url = std::env::var("HA_URL").unwrap_or_default();
36        let ha_token = std::env::var("HA_TOKEN").unwrap_or_default();
37        let (tx, rx) = mpsc::channel(config.mailbox_capacity);
38        Self {
39            config,
40            ha_url,
41            ha_token,
42            http: reqwest::Client::new(),
43            llm: None,
44            state: ActorState::Initializing,
45            metrics: Arc::new(ActorMetrics::new()),
46            mailbox_tx: tx,
47            mailbox_rx: Some(rx),
48            publisher: None,
49        }
50    }
51
52    pub fn with_publisher(mut self, p: EventPublisher) -> Self {
53        self.publisher = Some(p);
54        self
55    }
56
57    pub fn with_llm(mut self, llm_config: LlmConfig) -> Self {
58        let llm_cfg = ActorConfig::new(format!("{}-llm", self.config.name));
59        self.llm = Some(LlmAgent::new(llm_cfg, llm_config));
60        self
61    }
62
63    /// GET /api/states — return all entity states as JSON.
64    async fn get_states(&self) -> Result<serde_json::Value> {
65        let resp = self
66            .http
67            .get(format!("{}/api/states", self.ha_url))
68            .header("Authorization", format!("Bearer {}", self.ha_token))
69            .header("Content-Type", "application/json")
70            .send()
71            .await?;
72        Ok(resp.json().await?)
73    }
74
75    /// GET /api/states/<entity_id> — single entity state.
76    async fn get_state(&self, entity_id: &str) -> Result<serde_json::Value> {
77        let resp = self
78            .http
79            .get(format!("{}/api/states/{}", self.ha_url, entity_id))
80            .header("Authorization", format!("Bearer {}", self.ha_token))
81            .send()
82            .await?;
83        Ok(resp.json().await?)
84    }
85
86    /// POST /api/services/<domain>/<service> — call a HA service.
87    #[expect(dead_code)]
88    async fn call_service(
89        &self,
90        domain: &str,
91        service: &str,
92        data: serde_json::Value,
93    ) -> Result<serde_json::Value> {
94        let resp = self
95            .http
96            .post(format!(
97                "{}/api/services/{}/{}",
98                self.ha_url, domain, service
99            ))
100            .header("Authorization", format!("Bearer {}", self.ha_token))
101            .header("Content-Type", "application/json")
102            .json(&data)
103            .send()
104            .await?;
105        Ok(resp.json().await?)
106    }
107
108    async fn process_request(&mut self, text: &str) -> String {
109        // Simple keyword dispatch; LLM interprets if available
110        let lower = text.to_lowercase();
111
112        if lower.contains("states") || lower.contains("all entities") {
113            match self.get_states().await {
114                Ok(v) => format!(
115                    "HA states: {}",
116                    serde_json::to_string_pretty(&v).unwrap_or_else(|_| v.to_string())
117                ),
118                Err(e) => format!("HA error: {e}"),
119            }
120        } else if let Some(entity) = extract_entity_id(text) {
121            match self.get_state(&entity).await {
122                Ok(v) => format!("{entity}: {}", v["state"].as_str().unwrap_or("unknown")),
123                Err(e) => format!("HA error: {e}"),
124            }
125        } else if let Some(llm) = &mut self.llm {
126            let prompt = format!(
127                "You are a Home Assistant expert. The user said: \"{text}\"\n\
128                 Interpret this as a HA request and respond helpfully. \
129                 If you need to call a service, suggest: call_service(domain, service, {{data}})."
130            );
131            llm.complete(&prompt)
132                .await
133                .unwrap_or_else(|e| format!("LLM error: {e}"))
134        } else {
135            "I can query HA entity states. Try: 'get state light.living_room' or 'list all states'"
136                .into()
137        }
138    }
139
140    fn now_ms() -> u64 {
141        std::time::SystemTime::now()
142            .duration_since(std::time::UNIX_EPOCH)
143            .unwrap_or_default()
144            .as_millis() as u64
145    }
146}
147
/// Find the first whitespace-separated token in `text` that looks like a Home
/// Assistant entity id of the form `domain.object_id` (e.g. `light.living_room`).
///
/// Wrapping punctuation such as quotes, brackets or a sentence-ending `?`/`.` is
/// stripped before matching, and the result is lowercased. A token matches only
/// if it has exactly one `.` and both halves are non-empty runs of ASCII letters,
/// digits and `_`, with the domain starting with a letter. URLs, paths and
/// decimals like `3.5` are therefore skipped.
fn extract_entity_id(text: &str) -> Option<String> {
    // One half of `domain.object_id`: a non-empty run of `[A-Za-z0-9_]`.
    fn is_id_part(part: &str) -> bool {
        !part.is_empty() && part.chars().all(|c| c.is_ascii_alphanumeric() || c == '_')
    }

    text.split_whitespace().find_map(|word| {
        // Trim anything that can't be part of an id from both ends (this also
        // removes a trailing sentence period but keeps the inner separator).
        let candidate = word.trim_matches(|c: char| !(c.is_ascii_alphanumeric() || c == '_'));
        let (domain, object_id) = candidate.split_once('.')?;
        // Requiring a leading letter in the domain rejects decimals like `3.5`.
        let looks_like_entity = domain.starts_with(|c: char| c.is_ascii_alphabetic())
            && is_id_part(domain)
            && is_id_part(object_id);
        looks_like_entity.then(|| candidate.to_ascii_lowercase())
    })
}

157#[async_trait]
158impl Actor for HomeAssistantAgent {
159    fn id(&self) -> String {
160        self.config.id.clone()
161    }
162    fn name(&self) -> &str {
163        &self.config.name
164    }
165    fn state(&self) -> ActorState {
166        self.state.clone()
167    }
168    fn metrics(&self) -> Arc<ActorMetrics> {
169        Arc::clone(&self.metrics)
170    }
171    fn mailbox(&self) -> mpsc::Sender<Message> {
172        self.mailbox_tx.clone()
173    }
174
175    async fn on_start(&mut self) -> Result<()> {
176        self.state = ActorState::Running;
177        let connected = !self.ha_url.is_empty() && !self.ha_token.is_empty();
178        tracing::info!(
179            "[{}] HA agent started (connected={})",
180            self.config.name,
181            connected
182        );
183        if let Some(pub_) = &self.publisher {
184            pub_.publish(
185                wactorz_mqtt::topics::spawn(&self.config.id),
186                &serde_json::json!({
187                    "agentId":   self.config.id,
188                    "agentName": self.config.name,
189                    "agentType": "home_assistant",
190                    "haConnected": connected,
191                    "timestampMs": Self::now_ms(),
192                }),
193            );
194        }
195        Ok(())
196    }
197
198    async fn handle_message(&mut self, message: Message) -> Result<()> {
199        use wactorz_core::message::MessageType;
200        let text = match &message.payload {
201            MessageType::Text { content } => content.clone(),
202            MessageType::Task { description, .. } => description.clone(),
203            _ => return Ok(()),
204        };
205        let response = self.process_request(&text).await;
206        if let Some(pub_) = &self.publisher {
207            pub_.publish(
208                wactorz_mqtt::topics::chat(&self.config.id),
209                &serde_json::json!({
210                    "from":        self.config.name,
211                    "to":          message.from.as_deref().unwrap_or("user"),
212                    "content":     response,
213                    "timestampMs": Self::now_ms(),
214                }),
215            );
216        }
217        Ok(())
218    }
219
220    async fn on_heartbeat(&mut self) -> Result<()> {
221        if let Some(pub_) = &self.publisher {
222            pub_.publish(
223                wactorz_mqtt::topics::heartbeat(&self.config.id),
224                &serde_json::json!({
225                    "agentId":   self.config.id,
226                    "agentName": self.config.name,
227                    "state":     self.state,
228                    "haUrl":     self.ha_url,
229                    "timestampMs": Self::now_ms(),
230                }),
231            );
232        }
233        Ok(())
234    }
235
236    async fn run(&mut self) -> Result<()> {
237        self.on_start().await?;
238        let mut rx = self
239            .mailbox_rx
240            .take()
241            .ok_or_else(|| anyhow::anyhow!("HomeAssistantAgent already running"))?;
242        let mut hb = tokio::time::interval(std::time::Duration::from_secs(
243            self.config.heartbeat_interval_secs,
244        ));
245        hb.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
246        loop {
247            tokio::select! {
248                biased;
249                msg = rx.recv() => {
250                    match msg {
251                        None => break,
252                        Some(m) => {
253                            self.metrics.record_received();
254                            if let wactorz_core::message::MessageType::Command {
255                                command: wactorz_core::message::ActorCommand::Stop
256                            } = &m.payload { break; }
257                            match self.handle_message(m).await {
258                                Ok(_)  => self.metrics.record_processed(),
259                                Err(e) => { tracing::error!("[{}] {e}", self.config.name); self.metrics.record_failed(); }
260                            }
261                        }
262                    }
263                }
264                _ = hb.tick() => {
265                    self.metrics.record_heartbeat();
266                    if let Err(e) = self.on_heartbeat().await {
267                        tracing::error!("[{}] heartbeat: {e}", self.config.name);
268                    }
269                }
270            }
271        }
272        self.state = ActorState::Stopped;
273        self.on_stop().await
274    }
275}