// wactorz_agents/llm_agent.rs

1//! LLM provider abstraction.
2//!
3//! [`LlmAgent`] wraps multiple large-language-model backends behind a single
4//! async `complete()` interface.  Supported providers:
//! - **Anthropic** (`claude-*` models, Messages API)
//! - **OpenAI** (`gpt-*` and compatible, Chat Completions API)
//! - **Ollama** (local, OpenAI-compatible endpoint)
//! - **Google Gemini** (`gemini-*` models, `generateContent` API)
//! - **NVIDIA NIM** (integrate.api.nvidia.com, OpenAI-compatible)
8//!
9//! The active provider and model are selected via [`LlmConfig`].
10
11use anyhow::Result;
12use async_trait::async_trait;
13use serde::{Deserialize, Serialize};
14use std::sync::Arc;
15use tokio::sync::mpsc;
16
17use wactorz_core::{Actor, ActorConfig, ActorMetrics, ActorState, EventPublisher, Message};
18
/// Supported LLM provider backends.
///
/// Serializes in lowercase (e.g. `"anthropic"`, `"openai"`), matching both
/// the `Display` impl and the provider strings accepted by the
/// `wik/switch` task payload.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum LlmProvider {
    /// Anthropic Messages API (`claude-*` models). The default provider.
    #[default]
    Anthropic,
    /// OpenAI Chat Completions API (`gpt-*` and compatible).
    OpenAI,
    /// Local Ollama server (OpenAI-compatible endpoint; no API key needed).
    Ollama,
    /// Google Gemini (generativelanguage.googleapis.com).
    Gemini,
    /// NVIDIA NIM (integrate.api.nvidia.com) — OpenAI-compatible.
    /// Free tier: ~1000 API calls/month per model.
    Nim,
}
33
34impl std::fmt::Display for LlmProvider {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        match self {
37            LlmProvider::Anthropic => write!(f, "anthropic"),
38            LlmProvider::OpenAI => write!(f, "openai"),
39            LlmProvider::Ollama => write!(f, "ollama"),
40            LlmProvider::Gemini => write!(f, "gemini"),
41            LlmProvider::Nim => write!(f, "nim"),
42        }
43    }
44}
45
/// Per-model pricing in USD per 1M tokens, as `(input_rate, output_rate)`.
///
/// Matching is ordered: more specific prefixes (e.g. `gpt-4o-mini`) are
/// tested before their generic counterparts (`gpt-4o`). Llama models match
/// by substring so vendor-prefixed names (`meta/llama-…`) still resolve.
/// Unknown models fall back to a rough `(1.0, 3.0)` estimate.
fn pricing(model: &str) -> (f64, f64) {
    if model.starts_with("claude-sonnet-4-6") {
        (3.0, 15.0)
    } else if model.starts_with("claude-haiku-4-5") {
        (0.8, 4.0)
    } else if model.starts_with("claude-opus-4-6") {
        (15.0, 75.0)
    } else if model.starts_with("gpt-4o-mini") {
        (0.15, 0.6)
    } else if model.starts_with("gpt-4o") {
        (2.5, 10.0)
    } else if model.starts_with("deepseek") {
        (0.27, 1.10)
    } else if model.contains("llama-3.3-70b") {
        (0.39, 0.39)
    } else if model.contains("llama-3.1-8b") {
        (0.10, 0.10)
    } else if model.starts_with("gemini-2.0-flash") {
        (0.10, 0.40)
    } else if model.starts_with("gemini-1.5-pro") {
        (1.25, 5.0)
    } else {
        (1.0, 3.0)
    }
}
62
63/// Calculate cost in nano-USD from token counts and model name.
64pub fn calc_cost_nano_usd(model: &str, input_tokens: u64, output_tokens: u64) -> u64 {
65    let (in_price, out_price) = pricing(model);
66    let cost_usd =
67        (input_tokens as f64 * in_price + output_tokens as f64 * out_price) / 1_000_000.0;
68    (cost_usd * 1_000_000_000.0) as u64
69}
70
/// A single turn in a conversation (role + content).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Speaker role — `"user"` or `"assistant"` (OpenAI-style; the Gemini
    /// path translates `"assistant"` to `"model"` at request time).
    pub role: String,
    /// The message text for this turn.
    pub content: String,
}
77
/// Configuration for the LLM backend.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmConfig {
    /// Which backend to call; see [`LlmProvider`].
    pub provider: LlmProvider,
    /// Model name, e.g. `"claude-sonnet-4-6"`, `"gpt-4o"`, `"llama3"`.
    pub model: String,
    /// API key (Anthropic / OpenAI / Gemini / NIM). Not needed for Ollama.
    pub api_key: Option<String>,
    /// Base URL override (useful for Ollama or proxies).
    pub base_url: Option<String>,
    /// Maximum tokens to generate.
    pub max_tokens: u32,
    /// Sampling temperature.
    pub temperature: f32,
    /// Optional system prompt, sent in each provider's native format.
    pub system_prompt: Option<String>,
}
95
96impl Default for LlmConfig {
97    fn default() -> Self {
98        Self {
99            provider: LlmProvider::Anthropic,
100            model: "claude-sonnet-4-6".into(),
101            api_key: None,
102            base_url: None,
103            max_tokens: 4096,
104            temperature: 0.7,
105            system_prompt: None,
106        }
107    }
108}
109
/// An actor that calls an LLM provider and returns completions.
pub struct LlmAgent {
    /// Actor identity, mailbox capacity, heartbeat interval, etc.
    pub(crate) config: ActorConfig,
    /// Provider, model, and sampling parameters.
    pub(crate) llm_config: LlmConfig,
    /// HTTP client reused across all provider requests.
    pub(crate) http: reqwest::Client,
    /// Lifecycle state (Initializing → Running → Stopped).
    pub(crate) state: ActorState,
    pub(crate) metrics: Arc<ActorMetrics>,
    /// Sender half, handed out to peers via `mailbox()`.
    pub(crate) mailbox_tx: mpsc::Sender<Message>,
    /// Receiver half; taken (left `None`) once `run()` starts.
    pub(crate) mailbox_rx: Option<mpsc::Receiver<Message>>,
    /// Conversation history for multi-turn exchanges.
    pub(crate) history: Vec<ChatMessage>,
    /// Optional MQTT publisher for heartbeats and error events.
    pub(crate) publisher: Option<EventPublisher>,
    /// Consecutive API errors since last success — WIK monitors this via MQTT.
    pub(crate) consecutive_errors: u32,
}
125
126impl LlmAgent {
127    pub fn new(config: ActorConfig, llm_config: LlmConfig) -> Self {
128        let (tx, rx) = mpsc::channel(config.mailbox_capacity);
129        Self {
130            config,
131            llm_config,
132            http: reqwest::Client::new(),
133            state: ActorState::Initializing,
134            metrics: Arc::new(ActorMetrics::new()),
135            mailbox_tx: tx,
136            mailbox_rx: Some(rx),
137            history: Vec::new(),
138            publisher: None,
139            consecutive_errors: 0,
140        }
141    }
142
143    /// Attach an EventPublisher for MQTT output.
144    pub fn with_publisher(mut self, p: EventPublisher) -> Self {
145        self.publisher = Some(p);
146        self
147    }
148
149    /// Send a prompt to the configured LLM provider and return the completion.
150    /// Also records token usage and cost in the actor metrics.
151    pub async fn complete(&self, prompt: &str) -> Result<String> {
152        let (text, input_tok, output_tok) = match self.llm_config.provider {
153            LlmProvider::Anthropic => self.complete_anthropic(prompt).await?,
154            LlmProvider::OpenAI | LlmProvider::Ollama => {
155                self.complete_openai_compat(prompt, None).await?
156            }
157            LlmProvider::Nim => {
158                let base = "https://integrate.api.nvidia.com/v1";
159                self.complete_openai_compat(prompt, Some(base)).await?
160            }
161            LlmProvider::Gemini => self.complete_gemini(prompt).await?,
162        };
163        let cost_nano = calc_cost_nano_usd(&self.llm_config.model, input_tok, output_tok);
164        self.metrics
165            .record_llm_usage(input_tok, output_tok, cost_nano);
166        Ok(text)
167    }
168
169    fn now_ms() -> u64 {
170        std::time::SystemTime::now()
171            .duration_since(std::time::UNIX_EPOCH)
172            .unwrap_or_default()
173            .as_millis() as u64
174    }
175
176    /// Publish a provider error to `system/llm/error` so WIK can react.
177    fn publish_llm_error(&self, error: &str) {
178        if let Some(pub_) = &self.publisher {
179            pub_.publish(
180                wactorz_mqtt::topics::SYSTEM_LLM_ERROR,
181                &serde_json::json!({
182                    "provider":          self.llm_config.provider.to_string(),
183                    "model":             self.llm_config.model,
184                    "error":             error,
185                    "consecutiveErrors": self.consecutive_errors + 1,
186                    "timestampMs":       Self::now_ms(),
187                }),
188            );
189        }
190    }
191
192    /// Returns `(text, input_tokens, output_tokens)`.
193    async fn complete_anthropic(&self, prompt: &str) -> Result<(String, u64, u64)> {
194        let api_key = self
195            .llm_config
196            .api_key
197            .as_deref()
198            .ok_or_else(|| anyhow::anyhow!("LLM_API_KEY not set for Anthropic"))?;
199
200        let mut messages = serde_json::json!([]);
201        for m in &self.history {
202            messages
203                .as_array_mut()
204                .unwrap()
205                .push(serde_json::json!({"role": m.role, "content": m.content}));
206        }
207        messages
208            .as_array_mut()
209            .unwrap()
210            .push(serde_json::json!({"role": "user", "content": prompt}));
211
212        let mut body = serde_json::json!({
213            "model": self.llm_config.model,
214            "max_tokens": self.llm_config.max_tokens,
215            "messages": messages,
216        });
217        if let Some(sys) = &self.llm_config.system_prompt {
218            body["system"] = serde_json::Value::String(sys.clone());
219        }
220
221        let resp = self
222            .http
223            .post("https://api.anthropic.com/v1/messages")
224            .header("x-api-key", api_key)
225            .header("anthropic-version", "2023-06-01")
226            .json(&body)
227            .send()
228            .await?;
229
230        if !resp.status().is_success() {
231            let s = resp.status();
232            let t = resp.text().await.unwrap_or_default();
233            anyhow::bail!("Anthropic {s}: {t}");
234        }
235        let json: serde_json::Value = resp.json().await?;
236        let text = json["content"][0]["text"]
237            .as_str()
238            .ok_or_else(|| anyhow::anyhow!("unexpected Anthropic response: {json}"))?
239            .to_string();
240        let input_tok = json["usage"]["input_tokens"].as_u64().unwrap_or(0);
241        let output_tok = json["usage"]["output_tokens"].as_u64().unwrap_or(0);
242        Ok((text, input_tok, output_tok))
243    }
244
245    /// Returns `(text, input_tokens, output_tokens)`.
246    async fn complete_gemini(&self, prompt: &str) -> Result<(String, u64, u64)> {
247        let api_key = self
248            .llm_config
249            .api_key
250            .as_deref()
251            .ok_or_else(|| anyhow::anyhow!("LLM_API_KEY not set for Gemini"))?;
252
253        let model = &self.llm_config.model;
254        let url = format!(
255            "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}",
256            model, api_key
257        );
258
259        let mut contents: Vec<serde_json::Value> = self
260            .history
261            .iter()
262            .map(|m| {
263                let role = if m.role == "assistant" {
264                    "model"
265                } else {
266                    "user"
267                };
268                serde_json::json!({ "role": role, "parts": [{ "text": m.content }] })
269            })
270            .collect();
271        contents.push(serde_json::json!({
272            "role": "user",
273            "parts": [{ "text": prompt }]
274        }));
275
276        let mut body = serde_json::json!({ "contents": contents });
277        if let Some(sys) = &self.llm_config.system_prompt {
278            body["systemInstruction"] = serde_json::json!({ "parts": [{ "text": sys }] });
279        }
280
281        let resp = self.http.post(&url).json(&body).send().await?;
282        if !resp.status().is_success() {
283            let s = resp.status();
284            let t = resp.text().await.unwrap_or_default();
285            anyhow::bail!("Gemini {s}: {t}");
286        }
287        let json: serde_json::Value = resp.json().await?;
288        let text = json["candidates"][0]["content"]["parts"][0]["text"]
289            .as_str()
290            .ok_or_else(|| anyhow::anyhow!("unexpected Gemini response: {json}"))?
291            .to_string();
292        let input_tok = json["usageMetadata"]["promptTokenCount"]
293            .as_u64()
294            .unwrap_or(0);
295        let output_tok = json["usageMetadata"]["candidatesTokenCount"]
296            .as_u64()
297            .unwrap_or(0);
298        Ok((text, input_tok, output_tok))
299    }
300
301    /// OpenAI-compatible endpoint (OpenAI, Ollama, NIM).
302    /// `base_url_override` takes precedence over `llm_config.base_url`.
303    /// Returns `(text, input_tokens, output_tokens)`.
304    async fn complete_openai_compat(
305        &self,
306        prompt: &str,
307        base_url_override: Option<&str>,
308    ) -> Result<(String, u64, u64)> {
309        let base = base_url_override
310            .or(self.llm_config.base_url.as_deref())
311            .unwrap_or("https://api.openai.com/v1");
312
313        let mut msgs = Vec::new();
314        if let Some(sys) = &self.llm_config.system_prompt {
315            msgs.push(serde_json::json!({"role": "system", "content": sys}));
316        }
317        for m in &self.history {
318            msgs.push(serde_json::json!({"role": m.role, "content": m.content}));
319        }
320        msgs.push(serde_json::json!({"role": "user", "content": prompt}));
321
322        let body = serde_json::json!({
323            "model":       self.llm_config.model,
324            "messages":    msgs,
325            "max_tokens":  self.llm_config.max_tokens,
326            "temperature": self.llm_config.temperature,
327        });
328
329        let mut req = self
330            .http
331            .post(format!("{base}/chat/completions"))
332            .json(&body);
333        if let Some(key) = &self.llm_config.api_key {
334            req = req.header("Authorization", format!("Bearer {key}"));
335        }
336        let resp = req.send().await?;
337        if !resp.status().is_success() {
338            let s = resp.status();
339            let t = resp.text().await.unwrap_or_default();
340            anyhow::bail!("OpenAI-compat {s}: {t}");
341        }
342        let json: serde_json::Value = resp.json().await?;
343        let text = json["choices"][0]["message"]["content"]
344            .as_str()
345            .ok_or_else(|| anyhow::anyhow!("unexpected response: {json}"))?
346            .to_string();
347        let input_tok = json["usage"]["prompt_tokens"].as_u64().unwrap_or(0);
348        let output_tok = json["usage"]["completion_tokens"].as_u64().unwrap_or(0);
349        Ok((text, input_tok, output_tok))
350    }
351}
352
#[async_trait]
impl Actor for LlmAgent {
    fn id(&self) -> String {
        self.config.id.clone()
    }
    fn name(&self) -> &str {
        &self.config.name
    }
    fn state(&self) -> ActorState {
        self.state.clone()
    }
    fn metrics(&self) -> Arc<ActorMetrics> {
        Arc::clone(&self.metrics)
    }
    fn mailbox(&self) -> mpsc::Sender<Message> {
        self.mailbox_tx.clone()
    }
    fn is_protected(&self) -> bool {
        self.config.protected
    }

    /// Handle one mailbox message: either a WIK `wik/switch` provider
    /// hot-swap, or a prompt (Text / Task) sent to the configured LLM.
    ///
    /// On LLM success the turn is appended to `history` and
    /// `consecutive_errors` resets; on failure the error is published to
    /// MQTT and propagated to the caller (`run` records it as failed).
    async fn handle_message(&mut self, message: Message) -> Result<()> {
        use wactorz_core::message::MessageType;

        // ── WIK hot-swap: task_id "wik/switch" carries new provider config ──────
        if let MessageType::Task {
            task_id, payload, ..
        } = &message.payload
            && task_id == "wik/switch"
        {
            let provider_str = payload
                .get("provider")
                .and_then(|v| v.as_str())
                .unwrap_or("");
            let new_provider = match provider_str {
                "anthropic" => LlmProvider::Anthropic,
                "openai" => LlmProvider::OpenAI,
                "gemini" => LlmProvider::Gemini,
                "ollama" => LlmProvider::Ollama,
                "nim" => LlmProvider::Nim,
                other => {
                    // Unknown provider: warn and drop the switch request —
                    // the current provider stays active.
                    tracing::warn!(
                        "[{}] wik/switch: unknown provider '{other}'",
                        self.config.name
                    );
                    return Ok(());
                }
            };
            let reason = payload
                .get("reason")
                .and_then(|v| v.as_str())
                .unwrap_or("WIK switch");
            tracing::info!(
                "[{}] ⚡ provider switch: {} → {provider_str} ({reason})",
                self.config.name,
                self.llm_config.provider,
            );
            self.llm_config.provider = new_provider;
            // model / apiKey / baseUrl are each optional in the payload;
            // absent fields keep their current values.
            if let Some(model) = payload.get("model").and_then(|v| v.as_str()) {
                self.llm_config.model = model.to_string();
            }
            if let Some(key) = payload.get("apiKey").and_then(|v| v.as_str()) {
                self.llm_config.api_key = Some(key.to_string());
            }
            if let Some(url) = payload.get("baseUrl").and_then(|v| v.as_str()) {
                self.llm_config.base_url = Some(url.to_string());
            }
            // Fresh provider gets a clean error streak.
            self.consecutive_errors = 0;
            return Ok(());
        }

        // Anything that isn't Text or Task is silently ignored.
        let prompt = match &message.payload {
            MessageType::Text { content } => content.clone(),
            MessageType::Task { description, .. } => description.clone(),
            _ => return Ok(()),
        };

        match self.complete(&prompt).await {
            Ok(reply_text) => {
                self.consecutive_errors = 0;
                // Record both sides of the turn so later prompts carry context.
                self.history.push(ChatMessage {
                    role: "user".into(),
                    content: prompt,
                });
                self.history.push(ChatMessage {
                    role: "assistant".into(),
                    content: reply_text.clone(),
                });
                if let Some(sender_id) = message.from {
                    tracing::debug!(
                        "[{}] generated reply ({} chars)",
                        self.config.name,
                        reply_text.len()
                    );
                    // NOTE(review): the reply is constructed and then
                    // immediately discarded — confirm whether delivery
                    // (routing back to `sender_id`) is still unimplemented
                    // or intentionally disabled.
                    let reply =
                        Message::text(Some(self.config.id.clone()), Some(sender_id), reply_text);
                    let _ = reply;
                }
            }
            Err(e) => {
                // Increment BEFORE publishing so WIK sees the up-to-date streak.
                self.consecutive_errors += 1;
                let err_str = e.to_string();
                tracing::error!(
                    "[{}] LLM error (consecutive: {}) — {err_str}",
                    self.config.name,
                    self.consecutive_errors
                );
                self.publish_llm_error(&err_str);
                return Err(e);
            }
        }
        Ok(())
    }

    /// Publish a heartbeat snapshot (state, provider, token/cost totals)
    /// to this agent's MQTT heartbeat topic. No-op without a publisher.
    async fn on_heartbeat(&mut self) -> Result<()> {
        if let Some(pub_) = &self.publisher {
            let snap = self.metrics.snapshot();
            pub_.publish(
                wactorz_mqtt::topics::heartbeat(&self.config.id),
                &serde_json::json!({
                    "agentId":         self.config.id,
                    "agentName":       self.config.name,
                    "state":           self.state,
                    "provider":        self.llm_config.provider.to_string(),
                    "model":           self.llm_config.model,
                    "llmInputTokens":  snap.llm_input_tokens,
                    "llmOutputTokens": snap.llm_output_tokens,
                    "llmCostUsd":      snap.llm_cost_usd,
                    "restartCount":    snap.restart_count,
                    "sequence":        snap.heartbeats,
                    "timestampMs":     std::time::SystemTime::now()
                        .duration_since(std::time::UNIX_EPOCH)
                        .unwrap_or_default().as_millis() as u64,
                }),
            );
        }
        Ok(())
    }

    /// Main actor loop: processes mailbox messages and periodic heartbeats
    /// until the mailbox closes or a Stop command arrives.
    ///
    /// Takes ownership of the mailbox receiver; calling `run` twice on the
    /// same instance returns an error.
    async fn run(&mut self) -> Result<()> {
        self.on_start().await?;
        self.state = ActorState::Running;
        let mut rx = self
            .mailbox_rx
            .take()
            .ok_or_else(|| anyhow::anyhow!("LlmAgent already running"))?;
        let mut hb = tokio::time::interval(std::time::Duration::from_secs(
            self.config.heartbeat_interval_secs,
        ));
        // Skip (rather than burst) heartbeats delayed by long LLM calls.
        hb.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
        loop {
            tokio::select! {
                // `biased`: always drain pending messages before heartbeats.
                biased;
                msg = rx.recv() => {
                    match msg {
                        // Mailbox closed: all senders dropped — shut down.
                        None => break,
                        Some(m) => {
                            self.metrics.record_received();
                            if let wactorz_core::message::MessageType::Command {
                                command: wactorz_core::message::ActorCommand::Stop
                            } = &m.payload {
                                break;
                            }
                            match self.handle_message(m).await {
                                Ok(_) => self.metrics.record_processed(),
                                Err(e) => {
                                    // Errors are logged and counted but do not
                                    // stop the loop — the actor keeps running.
                                    tracing::error!("[{}] {e}", self.config.name);
                                    self.metrics.record_failed();
                                }
                            }
                        }
                    }
                }
                _ = hb.tick() => {
                    self.metrics.record_heartbeat();
                    if let Err(e) = self.on_heartbeat().await {
                        tracing::error!("[{}] heartbeat: {e}", self.config.name);
                    }
                }
            }
        }
        self.state = ActorState::Stopped;
        self.on_stop().await
    }
}