1use anyhow::Result;
12use async_trait::async_trait;
13use serde::{Deserialize, Serialize};
14use std::sync::Arc;
15use tokio::sync::mpsc;
16
17use wactorz_core::{Actor, ActorConfig, ActorMetrics, ActorState, EventPublisher, Message};
18
/// Supported LLM backends. Serialized in lowercase (e.g. "openai") via serde,
/// matching both the `Display` impl below and the provider strings accepted by
/// the `wik/switch` task in `handle_message`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum LlmProvider {
    /// Anthropic Messages API (the default provider).
    #[default]
    Anthropic,
    /// OpenAI chat-completions API.
    OpenAI,
    /// Ollama server, spoken to over the OpenAI-compatible endpoint.
    Ollama,
    /// Google Gemini `generateContent` REST API.
    Gemini,
    /// NVIDIA NIM — OpenAI-compatible, hosted at integrate.api.nvidia.com.
    Nim,
}
33
34impl std::fmt::Display for LlmProvider {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 match self {
37 LlmProvider::Anthropic => write!(f, "anthropic"),
38 LlmProvider::OpenAI => write!(f, "openai"),
39 LlmProvider::Ollama => write!(f, "ollama"),
40 LlmProvider::Gemini => write!(f, "gemini"),
41 LlmProvider::Nim => write!(f, "nim"),
42 }
43 }
44}
45
/// Look up (input, output) prices in USD per million tokens for a model name.
///
/// Rules are evaluated top-to-bottom; the first match wins (so "gpt-4o-mini"
/// must precede "gpt-4o"). Unknown models fall back to (1.0, 3.0).
fn pricing(model: &str) -> (f64, f64) {
    // Matcher kind, pattern, and ($/Mtok in, $/Mtok out) — order is significant.
    let prefix = |m: &str, p: &str| m.starts_with(p);
    let substr = |m: &str, p: &str| m.contains(p);
    let rules: [(fn(&str, &str) -> bool, &str, (f64, f64)); 10] = [
        (prefix, "claude-sonnet-4-6", (3.0, 15.0)),
        (prefix, "claude-haiku-4-5", (0.8, 4.0)),
        (prefix, "claude-opus-4-6", (15.0, 75.0)),
        (prefix, "gpt-4o-mini", (0.15, 0.6)),
        (prefix, "gpt-4o", (2.5, 10.0)),
        (prefix, "deepseek", (0.27, 1.10)),
        (substr, "llama-3.3-70b", (0.39, 0.39)),
        (substr, "llama-3.1-8b", (0.10, 0.10)),
        (prefix, "gemini-2.0-flash", (0.10, 0.40)),
        (prefix, "gemini-1.5-pro", (1.25, 5.0)),
    ];
    rules
        .iter()
        .find(|(matches, pat, _)| matches(model, pat))
        .map(|&(_, _, rate)| rate)
        .unwrap_or((1.0, 3.0))
}
62
63pub fn calc_cost_nano_usd(model: &str, input_tokens: u64, output_tokens: u64) -> u64 {
65 let (in_price, out_price) = pricing(model);
66 let cost_usd =
67 (input_tokens as f64 * in_price + output_tokens as f64 * out_price) / 1_000_000.0;
68 (cost_usd * 1_000_000_000.0) as u64
69}
70
/// One turn of the conversation history an `LlmAgent` replays to its provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    // Speaker role: "user" or "assistant" (see `handle_message`).
    pub role: String,
    // Raw text content of the turn.
    pub content: String,
}
77
/// Provider and model settings for one `LlmAgent`; mutable at runtime via
/// the `wik/switch` task.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmConfig {
    // Which backend `complete` dispatches to.
    pub provider: LlmProvider,
    // Provider-specific model identifier; also drives cost estimation.
    pub model: String,
    // API key. Required by the Anthropic and Gemini paths (they error without
    // it); optional for OpenAI-compatible servers (e.g. a local Ollama).
    pub api_key: Option<String>,
    // Base URL override for the OpenAI-compatible path
    // (defaults to https://api.openai.com/v1).
    pub base_url: Option<String>,
    // Maximum tokens to generate per completion.
    pub max_tokens: u32,
    // Sampling temperature. Only sent on the OpenAI-compatible path; the
    // Anthropic and Gemini request bodies omit it.
    pub temperature: f32,
    // Optional system prompt, mapped to each provider's native mechanism
    // (top-level "system" for Anthropic, "systemInstruction" for Gemini,
    // a leading "system" message for OpenAI-compatible).
    pub system_prompt: Option<String>,
}
95
96impl Default for LlmConfig {
97 fn default() -> Self {
98 Self {
99 provider: LlmProvider::Anthropic,
100 model: "claude-sonnet-4-6".into(),
101 api_key: None,
102 base_url: None,
103 max_tokens: 4096,
104 temperature: 0.7,
105 system_prompt: None,
106 }
107 }
108}
109
/// An actor that answers incoming text/task messages by calling an LLM
/// provider over HTTP, tracking token usage and cost in its metrics.
pub struct LlmAgent {
    // Core actor settings (id, name, mailbox capacity, heartbeat interval, ...).
    pub(crate) config: ActorConfig,
    // Active provider/model configuration; swapped at runtime by `wik/switch`.
    pub(crate) llm_config: LlmConfig,
    // HTTP client reused across all provider calls.
    pub(crate) http: reqwest::Client,
    // Lifecycle state reported through the `Actor` trait.
    pub(crate) state: ActorState,
    // Counters for processed/failed messages, heartbeats, and LLM usage/cost.
    pub(crate) metrics: Arc<ActorMetrics>,
    // Sender half handed out via `Actor::mailbox()`.
    pub(crate) mailbox_tx: mpsc::Sender<Message>,
    // Receiver half; taken by `run()`, so it is None while the actor runs.
    pub(crate) mailbox_rx: Option<mpsc::Receiver<Message>>,
    // Conversation history replayed to the provider on every completion.
    pub(crate) history: Vec<ChatMessage>,
    // Optional MQTT publisher for heartbeats and LLM error events.
    pub(crate) publisher: Option<EventPublisher>,
    // Back-to-back completion failures; reset on success or provider switch.
    pub(crate) consecutive_errors: u32,
}
125
126impl LlmAgent {
127 pub fn new(config: ActorConfig, llm_config: LlmConfig) -> Self {
128 let (tx, rx) = mpsc::channel(config.mailbox_capacity);
129 Self {
130 config,
131 llm_config,
132 http: reqwest::Client::new(),
133 state: ActorState::Initializing,
134 metrics: Arc::new(ActorMetrics::new()),
135 mailbox_tx: tx,
136 mailbox_rx: Some(rx),
137 history: Vec::new(),
138 publisher: None,
139 consecutive_errors: 0,
140 }
141 }
142
143 pub fn with_publisher(mut self, p: EventPublisher) -> Self {
145 self.publisher = Some(p);
146 self
147 }
148
149 pub async fn complete(&self, prompt: &str) -> Result<String> {
152 let (text, input_tok, output_tok) = match self.llm_config.provider {
153 LlmProvider::Anthropic => self.complete_anthropic(prompt).await?,
154 LlmProvider::OpenAI | LlmProvider::Ollama => {
155 self.complete_openai_compat(prompt, None).await?
156 }
157 LlmProvider::Nim => {
158 let base = "https://integrate.api.nvidia.com/v1";
159 self.complete_openai_compat(prompt, Some(base)).await?
160 }
161 LlmProvider::Gemini => self.complete_gemini(prompt).await?,
162 };
163 let cost_nano = calc_cost_nano_usd(&self.llm_config.model, input_tok, output_tok);
164 self.metrics
165 .record_llm_usage(input_tok, output_tok, cost_nano);
166 Ok(text)
167 }
168
169 fn now_ms() -> u64 {
170 std::time::SystemTime::now()
171 .duration_since(std::time::UNIX_EPOCH)
172 .unwrap_or_default()
173 .as_millis() as u64
174 }
175
176 fn publish_llm_error(&self, error: &str) {
178 if let Some(pub_) = &self.publisher {
179 pub_.publish(
180 wactorz_mqtt::topics::SYSTEM_LLM_ERROR,
181 &serde_json::json!({
182 "provider": self.llm_config.provider.to_string(),
183 "model": self.llm_config.model,
184 "error": error,
185 "consecutiveErrors": self.consecutive_errors + 1,
186 "timestampMs": Self::now_ms(),
187 }),
188 );
189 }
190 }
191
192 async fn complete_anthropic(&self, prompt: &str) -> Result<(String, u64, u64)> {
194 let api_key = self
195 .llm_config
196 .api_key
197 .as_deref()
198 .ok_or_else(|| anyhow::anyhow!("LLM_API_KEY not set for Anthropic"))?;
199
200 let mut messages = serde_json::json!([]);
201 for m in &self.history {
202 messages
203 .as_array_mut()
204 .unwrap()
205 .push(serde_json::json!({"role": m.role, "content": m.content}));
206 }
207 messages
208 .as_array_mut()
209 .unwrap()
210 .push(serde_json::json!({"role": "user", "content": prompt}));
211
212 let mut body = serde_json::json!({
213 "model": self.llm_config.model,
214 "max_tokens": self.llm_config.max_tokens,
215 "messages": messages,
216 });
217 if let Some(sys) = &self.llm_config.system_prompt {
218 body["system"] = serde_json::Value::String(sys.clone());
219 }
220
221 let resp = self
222 .http
223 .post("https://api.anthropic.com/v1/messages")
224 .header("x-api-key", api_key)
225 .header("anthropic-version", "2023-06-01")
226 .json(&body)
227 .send()
228 .await?;
229
230 if !resp.status().is_success() {
231 let s = resp.status();
232 let t = resp.text().await.unwrap_or_default();
233 anyhow::bail!("Anthropic {s}: {t}");
234 }
235 let json: serde_json::Value = resp.json().await?;
236 let text = json["content"][0]["text"]
237 .as_str()
238 .ok_or_else(|| anyhow::anyhow!("unexpected Anthropic response: {json}"))?
239 .to_string();
240 let input_tok = json["usage"]["input_tokens"].as_u64().unwrap_or(0);
241 let output_tok = json["usage"]["output_tokens"].as_u64().unwrap_or(0);
242 Ok((text, input_tok, output_tok))
243 }
244
245 async fn complete_gemini(&self, prompt: &str) -> Result<(String, u64, u64)> {
247 let api_key = self
248 .llm_config
249 .api_key
250 .as_deref()
251 .ok_or_else(|| anyhow::anyhow!("LLM_API_KEY not set for Gemini"))?;
252
253 let model = &self.llm_config.model;
254 let url = format!(
255 "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}",
256 model, api_key
257 );
258
259 let mut contents: Vec<serde_json::Value> = self
260 .history
261 .iter()
262 .map(|m| {
263 let role = if m.role == "assistant" {
264 "model"
265 } else {
266 "user"
267 };
268 serde_json::json!({ "role": role, "parts": [{ "text": m.content }] })
269 })
270 .collect();
271 contents.push(serde_json::json!({
272 "role": "user",
273 "parts": [{ "text": prompt }]
274 }));
275
276 let mut body = serde_json::json!({ "contents": contents });
277 if let Some(sys) = &self.llm_config.system_prompt {
278 body["systemInstruction"] = serde_json::json!({ "parts": [{ "text": sys }] });
279 }
280
281 let resp = self.http.post(&url).json(&body).send().await?;
282 if !resp.status().is_success() {
283 let s = resp.status();
284 let t = resp.text().await.unwrap_or_default();
285 anyhow::bail!("Gemini {s}: {t}");
286 }
287 let json: serde_json::Value = resp.json().await?;
288 let text = json["candidates"][0]["content"]["parts"][0]["text"]
289 .as_str()
290 .ok_or_else(|| anyhow::anyhow!("unexpected Gemini response: {json}"))?
291 .to_string();
292 let input_tok = json["usageMetadata"]["promptTokenCount"]
293 .as_u64()
294 .unwrap_or(0);
295 let output_tok = json["usageMetadata"]["candidatesTokenCount"]
296 .as_u64()
297 .unwrap_or(0);
298 Ok((text, input_tok, output_tok))
299 }
300
301 async fn complete_openai_compat(
305 &self,
306 prompt: &str,
307 base_url_override: Option<&str>,
308 ) -> Result<(String, u64, u64)> {
309 let base = base_url_override
310 .or(self.llm_config.base_url.as_deref())
311 .unwrap_or("https://api.openai.com/v1");
312
313 let mut msgs = Vec::new();
314 if let Some(sys) = &self.llm_config.system_prompt {
315 msgs.push(serde_json::json!({"role": "system", "content": sys}));
316 }
317 for m in &self.history {
318 msgs.push(serde_json::json!({"role": m.role, "content": m.content}));
319 }
320 msgs.push(serde_json::json!({"role": "user", "content": prompt}));
321
322 let body = serde_json::json!({
323 "model": self.llm_config.model,
324 "messages": msgs,
325 "max_tokens": self.llm_config.max_tokens,
326 "temperature": self.llm_config.temperature,
327 });
328
329 let mut req = self
330 .http
331 .post(format!("{base}/chat/completions"))
332 .json(&body);
333 if let Some(key) = &self.llm_config.api_key {
334 req = req.header("Authorization", format!("Bearer {key}"));
335 }
336 let resp = req.send().await?;
337 if !resp.status().is_success() {
338 let s = resp.status();
339 let t = resp.text().await.unwrap_or_default();
340 anyhow::bail!("OpenAI-compat {s}: {t}");
341 }
342 let json: serde_json::Value = resp.json().await?;
343 let text = json["choices"][0]["message"]["content"]
344 .as_str()
345 .ok_or_else(|| anyhow::anyhow!("unexpected response: {json}"))?
346 .to_string();
347 let input_tok = json["usage"]["prompt_tokens"].as_u64().unwrap_or(0);
348 let output_tok = json["usage"]["completion_tokens"].as_u64().unwrap_or(0);
349 Ok((text, input_tok, output_tok))
350 }
351}
352
#[async_trait]
impl Actor for LlmAgent {
    // --- Identity / introspection accessors required by the Actor trait. ---
    fn id(&self) -> String {
        self.config.id.clone()
    }
    fn name(&self) -> &str {
        &self.config.name
    }
    fn state(&self) -> ActorState {
        self.state.clone()
    }
    fn metrics(&self) -> Arc<ActorMetrics> {
        Arc::clone(&self.metrics)
    }
    fn mailbox(&self) -> mpsc::Sender<Message> {
        self.mailbox_tx.clone()
    }
    fn is_protected(&self) -> bool {
        self.config.protected
    }

    /// Handle one inbound message: either a `wik/switch` control task that
    /// reconfigures the provider in place, or a text/task prompt that is sent
    /// to the LLM and recorded (with its reply) in the conversation history.
    async fn handle_message(&mut self, message: Message) -> Result<()> {
        use wactorz_core::message::MessageType;

        // Control path: hot-swap provider/model/credentials without a restart.
        if let MessageType::Task {
            task_id, payload, ..
        } = &message.payload
            && task_id == "wik/switch"
        {
            let provider_str = payload
                .get("provider")
                .and_then(|v| v.as_str())
                .unwrap_or("");
            let new_provider = match provider_str {
                "anthropic" => LlmProvider::Anthropic,
                "openai" => LlmProvider::OpenAI,
                "gemini" => LlmProvider::Gemini,
                "ollama" => LlmProvider::Ollama,
                "nim" => LlmProvider::Nim,
                other => {
                    // Unknown provider: warn and drop the task instead of erroring.
                    tracing::warn!(
                        "[{}] wik/switch: unknown provider '{other}'",
                        self.config.name
                    );
                    return Ok(());
                }
            };
            let reason = payload
                .get("reason")
                .and_then(|v| v.as_str())
                .unwrap_or("WIK switch");
            tracing::info!(
                "[{}] ⚡ provider switch: {} → {provider_str} ({reason})",
                self.config.name,
                self.llm_config.provider,
            );
            self.llm_config.provider = new_provider;
            // Optional overrides; absent keys leave the current values in place.
            if let Some(model) = payload.get("model").and_then(|v| v.as_str()) {
                self.llm_config.model = model.to_string();
            }
            if let Some(key) = payload.get("apiKey").and_then(|v| v.as_str()) {
                self.llm_config.api_key = Some(key.to_string());
            }
            if let Some(url) = payload.get("baseUrl").and_then(|v| v.as_str()) {
                self.llm_config.base_url = Some(url.to_string());
            }
            // A fresh provider gets a clean error streak.
            self.consecutive_errors = 0;
            return Ok(());
        }

        // Prompt path: Text and Task messages carry a prompt; anything else
        // (commands, etc.) is ignored here.
        let prompt = match &message.payload {
            MessageType::Text { content } => content.clone(),
            MessageType::Task { description, .. } => description.clone(),
            _ => return Ok(()),
        };

        match self.complete(&prompt).await {
            Ok(reply_text) => {
                self.consecutive_errors = 0;
                // Record both sides of the exchange so later completions see it.
                self.history.push(ChatMessage {
                    role: "user".into(),
                    content: prompt,
                });
                self.history.push(ChatMessage {
                    role: "assistant".into(),
                    content: reply_text.clone(),
                });
                if let Some(sender_id) = message.from {
                    tracing::debug!(
                        "[{}] generated reply ({} chars)",
                        self.config.name,
                        reply_text.len()
                    );
                    let reply =
                        Message::text(Some(self.config.id.clone()), Some(sender_id), reply_text);
                    // NOTE(review): the reply is constructed but never delivered —
                    // presumably routing back to the sender is still TODO; confirm.
                    let _ = reply;
                }
            }
            Err(e) => {
                self.consecutive_errors += 1;
                let err_str = e.to_string();
                tracing::error!(
                    "[{}] LLM error (consecutive: {}) — {err_str}",
                    self.config.name,
                    self.consecutive_errors
                );
                // Emit an MQTT error event (see publish_llm_error), then let the
                // run loop record the failure via the returned Err.
                self.publish_llm_error(&err_str);
                return Err(e);
            }
        }
        Ok(())
    }

    /// Publish a heartbeat snapshot (state, provider, token/cost counters) to
    /// the agent's MQTT heartbeat topic; a no-op without a publisher.
    async fn on_heartbeat(&mut self) -> Result<()> {
        if let Some(pub_) = &self.publisher {
            let snap = self.metrics.snapshot();
            pub_.publish(
                wactorz_mqtt::topics::heartbeat(&self.config.id),
                &serde_json::json!({
                    "agentId": self.config.id,
                    "agentName": self.config.name,
                    "state": self.state,
                    "provider": self.llm_config.provider.to_string(),
                    "model": self.llm_config.model,
                    "llmInputTokens": snap.llm_input_tokens,
                    "llmOutputTokens": snap.llm_output_tokens,
                    "llmCostUsd": snap.llm_cost_usd,
                    "restartCount": snap.restart_count,
                    "sequence": snap.heartbeats,
                    "timestampMs": std::time::SystemTime::now()
                        .duration_since(std::time::UNIX_EPOCH)
                        .unwrap_or_default().as_millis() as u64,
                }),
            );
        }
        Ok(())
    }

    /// Main actor loop: drain the mailbox and emit heartbeats on a timer until
    /// the mailbox closes or a Stop command arrives, then transition to Stopped.
    async fn run(&mut self) -> Result<()> {
        self.on_start().await?;
        self.state = ActorState::Running;
        // Taking the receiver makes a second run() call fail fast.
        let mut rx = self
            .mailbox_rx
            .take()
            .ok_or_else(|| anyhow::anyhow!("LlmAgent already running"))?;
        let mut hb = tokio::time::interval(std::time::Duration::from_secs(
            self.config.heartbeat_interval_secs,
        ));
        // Skip (rather than burst) ticks missed while a slow LLM call runs.
        hb.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
        loop {
            tokio::select! {
                // `biased` polls the mailbox arm first, so pending messages
                // take priority over heartbeat ticks.
                biased;
                msg = rx.recv() => {
                    match msg {
                        // All senders dropped — shut down.
                        None => break,
                        Some(m) => {
                            self.metrics.record_received();
                            // Stop command exits the loop before any handling.
                            if let wactorz_core::message::MessageType::Command {
                                command: wactorz_core::message::ActorCommand::Stop
                            } = &m.payload {
                                break;
                            }
                            match self.handle_message(m).await {
                                Ok(_) => self.metrics.record_processed(),
                                Err(e) => {
                                    // Failures are logged and counted; the loop keeps running.
                                    tracing::error!("[{}] {e}", self.config.name);
                                    self.metrics.record_failed();
                                }
                            }
                        }
                    }
                }
                _ = hb.tick() => {
                    self.metrics.record_heartbeat();
                    if let Err(e) = self.on_heartbeat().await {
                        tracing::error!("[{}] heartbeat: {e}", self.config.name);
                    }
                }
            }
        }
        self.state = ActorState::Stopped;
        self.on_stop().await
    }
}