wactorz_agents/
wik_agent.rs

1//! Key-management agent — **WIK** (Waldiez Intelligence Keys).
2//!
3//! Manages LLM provider credentials, monitors API errors, and automatically
4//! fails over to the next-priority provider when the active one hits errors
5//! or rate limits.  Also tracks per-provider usage counts for WIF integration.
6//!
7//! NATO node: **kilo** (K → Keys)
8//!
9//! ## Failover behaviour
10//!
11//! WIK listens on `system/llm/error`.  After `error_threshold` consecutive
12//! errors from the active provider it publishes `system/llm/switch` with the
13//! next-priority provider config, which LlmAgent / MainActor applies live.
14//!
15//! Default threshold: **3** consecutive errors (configurable via `set threshold`).
16//!
17//! ## Usage (via IO bar)
18//!
19//! ```text
20//! @wik-agent status                               → all providers + active
21//! @wik-agent add anthropic <key> [model]          → register/update
22//! @wik-agent add gemini <key> [model]             → register fallback
23//! @wik-agent add openai <key> [model]             → register (⚠ warnings shown)
24//! @wik-agent priority anthropic gemini openai     → set fallback order
25//! @wik-agent switch gemini [reason]               → manually activate
//! @wik-agent test [provider]                      → config sanity check (no live ping)
27//! @wik-agent usage                                → call + error counts
28//! @wik-agent rotate <provider>                    → flag for key rotation
29//! @wik-agent set threshold <n>                    → errors before failover (default 3)
30//! @wik-agent help
31//! ```
32
33use anyhow::Result;
34use async_trait::async_trait;
35use std::{
36    sync::{Arc, Mutex},
37    time::{SystemTime, UNIX_EPOCH},
38};
39use tokio::sync::mpsc;
40
41use wactorz_core::{Actor, ActorConfig, ActorMetrics, ActorState, EventPublisher, Message};
42
43// ── GPT/OpenAI warning ─────────────────────────────────────────────────────────
44
/// Markdown notice returned by `cmd_add` when `openai` is added without the
/// `--confirm` flag. It summarises cost and data-privacy caveats and tells the
/// user how to register the provider anyway.
const OPENAI_WARNING: &str = "⚠ **OpenAI / GPT — read before using as fallback**\n\n\
**Cost**\n\
• GPT-4o is significantly more expensive than Gemini Flash or Claude Haiku\n\
• Rate-limit overages bill automatically — set a spend cap at platform.openai.com\n\
• _Recommendation_: use Gemini as first fallback; GPT as last-resort only\n\n\
**Data & Privacy**\n\
• OpenAI may use API inputs to improve models **by default** on some plans\n\
• To opt out: platform.openai.com → Settings → Data Controls → \"Improve model for everyone\" → OFF\n\
• Zero-retention is available on Enterprise / API tiers with a Data Processing Addendum (DPA)\n\
• Inputs sent to OpenAI may be stored for up to 30 days for abuse monitoring\n\n\
_Type `@wik-agent add openai <key> --confirm` to register anyway._";
56
57// ── Data model ─────────────────────────────────────────────────────────────────
58
/// One registered LLM provider together with its credentials and counters.
#[derive(Clone)]
struct ProviderEntry {
    /// Provider identifier: "anthropic" | "gemini" | "openai" | "ollama".
    name: String,
    /// Secret API key as entered by the user.
    api_key: String,
    /// Model name sent along with a provider switch.
    model: String,
    /// Optional endpoint override; `None` for entries created by `add`.
    base_url: Option<String>,
    /// Position in the failover order (1 = highest).
    priority: usize,
    /// Number of calls attributed to this provider.
    call_count: u64,
    /// Number of errors attributed to this provider.
    error_count: u64,
    /// Set when the key has been flagged for rotation.
    rotate_flag: bool,
    /// Whether this provider is the one currently in use.
    active: bool,
}

impl ProviderEntry {
    /// Return the model used when the user does not name one for `name`.
    ///
    /// Matching is exact (case-sensitive); any unrecognised provider yields
    /// `"unknown"`.
    fn default_model(name: &str) -> &'static str {
        // (provider, default model) pairs, looked up linearly.
        const DEFAULTS: [(&str, &str); 4] = [
            ("anthropic", "claude-sonnet-4-6"),
            ("gemini", "gemini-2.0-flash"),
            ("openai", "gpt-4o"),
            ("ollama", "llama3"),
        ];

        DEFAULTS
            .iter()
            .find(|(provider, _)| *provider == name)
            .map_or("unknown", |(_, model)| *model)
    }
}
83
84// ── WikAgent ───────────────────────────────────────────────────────────────────
85
/// Key-management actor: keeps the provider registry, counts LLM errors and
/// publishes provider switches.
pub struct WikAgent {
    /// Actor configuration (id, name, mailbox capacity, heartbeat interval, protected flag).
    config: ActorConfig,
    /// Lifecycle state; set to `Running` in `on_start` and `Stopped` at the end of `run`.
    state: ActorState,
    /// Shared metrics handle updated by the run loop.
    metrics: Arc<ActorMetrics>,
    /// Sending half of the mailbox, cloned out through `mailbox()`.
    mailbox_tx: mpsc::Sender<Message>,
    /// Receiving half of the mailbox; taken (left as `None`) by `run`.
    mailbox_rx: Option<mpsc::Receiver<Message>>,
    /// Event publisher; when `None`, nothing is published.
    publisher: Option<EventPublisher>,
    /// Registered providers, in insertion order.
    providers: Arc<Mutex<Vec<ProviderEntry>>>,
    /// Consecutive errors received from the current active provider.
    consecutive_errors: Arc<Mutex<u32>>,
    /// How many consecutive errors before triggering failover.
    error_threshold: Arc<Mutex<u32>>,
}
99
100impl WikAgent {
101    pub fn new(config: ActorConfig) -> Self {
102        let (tx, rx) = mpsc::channel(config.mailbox_capacity);
103        Self {
104            config,
105            state: ActorState::Initializing,
106            metrics: Arc::new(ActorMetrics::new()),
107            mailbox_tx: tx,
108            mailbox_rx: Some(rx),
109            publisher: None,
110            providers: Arc::new(Mutex::new(Vec::new())),
111            consecutive_errors: Arc::new(Mutex::new(0)),
112            error_threshold: Arc::new(Mutex::new(3)),
113        }
114    }
115
116    pub fn with_publisher(mut self, p: EventPublisher) -> Self {
117        self.publisher = Some(p);
118        self
119    }
120
121    fn now_ms() -> u64 {
122        SystemTime::now()
123            .duration_since(UNIX_EPOCH)
124            .unwrap_or_default()
125            .as_millis() as u64
126    }
127
128    fn reply(&self, content: &str) {
129        if let Some(pub_) = &self.publisher {
130            pub_.publish(
131                wactorz_mqtt::topics::chat(&self.config.id),
132                &serde_json::json!({
133                    "from":        self.config.name,
134                    "to":          "user",
135                    "content":     content,
136                    "timestampMs": Self::now_ms(),
137                }),
138            );
139        }
140    }
141
142    /// Publish `system/llm/switch` to trigger a live provider swap in LlmAgent.
143    fn publish_switch(&self, entry: &ProviderEntry, reason: &str) {
144        if let Some(pub_) = &self.publisher {
145            pub_.publish(
146                wactorz_mqtt::topics::SYSTEM_LLM_SWITCH,
147                &serde_json::json!({
148                    "provider":    entry.name,
149                    "model":       entry.model,
150                    "apiKey":      entry.api_key,
151                    "baseUrl":     entry.base_url,
152                    "reason":      reason,
153                    "timestampMs": Self::now_ms(),
154                }),
155            );
156        }
157    }
158
159    /// Called when a `system/llm/error` arrives.  Increments error count for
160    /// the active provider and triggers failover when threshold is reached.
161    fn handle_llm_error(&self, payload: &serde_json::Value) {
162        let incoming_provider = payload
163            .get("provider")
164            .and_then(|v| v.as_str())
165            .unwrap_or("");
166        let error_msg = payload
167            .get("error")
168            .and_then(|v| v.as_str())
169            .unwrap_or("unknown error");
170
171        let mut providers = self.providers.lock().unwrap();
172
173        // Update error counter for that provider
174        if let Some(entry) = providers
175            .iter_mut()
176            .find(|e| e.active && e.name == incoming_provider)
177        {
178            entry.error_count += 1;
179        }
180
181        let mut consecutive = self.consecutive_errors.lock().unwrap();
182        *consecutive += 1;
183
184        let threshold = *self.error_threshold.lock().unwrap();
185        tracing::warn!(
186            "[wik-agent] LLM error #{} from '{incoming_provider}': {error_msg}",
187            *consecutive,
188        );
189
190        if *consecutive < threshold {
191            return; // not yet — wait for more
192        }
193
194        // Threshold reached — find the next provider in priority order
195        let next = {
196            let active_priority = providers
197                .iter()
198                .find(|e| e.active)
199                .map(|e| e.priority)
200                .unwrap_or(0);
201            providers
202                .iter()
203                .filter(|e| !e.active && e.priority > active_priority && !e.api_key.is_empty())
204                .min_by_key(|e| e.priority)
205                .cloned()
206        };
207
208        match next {
209            None => {
210                tracing::error!("[wik-agent] failover: no more providers available!");
211                self.reply(&format!(
212                    "🔴 **WIK failover failed** — no more providers in queue.\n\n\
213                     `{incoming_provider}` has hit {threshold} consecutive errors.\n\n\
214                     Add a fallback provider:\n\
215                     `@wik-agent add gemini <key>`"
216                ));
217            }
218            Some(ref next_entry) => {
219                let reason = format!(
220                    "auto-failover: {incoming_provider} hit {threshold} consecutive errors"
221                );
222                tracing::info!(
223                    "[wik-agent] ⚡ failover: {incoming_provider} → {}",
224                    next_entry.name
225                );
226
227                // Mark active/inactive
228                for e in providers.iter_mut() {
229                    e.active = e.name == next_entry.name;
230                }
231                *consecutive = 0;
232                drop(providers);
233                drop(consecutive);
234
235                self.reply(&format!(
236                    "⚡ **WIK Auto-Failover**\n\n\
237                     `{incoming_provider}` → **`{}`** ({})\n\n\
238                     _{reason}_\n\n\
239                     Use `@wik-agent status` to review. `@wik-agent switch {incoming_provider}` to revert.",
240                    next_entry.name, next_entry.model,
241                ));
242                self.publish_switch(next_entry, &reason);
243            }
244        }
245    }
246
247    // ── Command handlers ────────────────────────────────────────────────────────
248
249    fn cmd_status(&self) -> String {
250        let providers = self.providers.lock().unwrap();
251        if providers.is_empty() {
252            return "📭 No providers configured.\n\n\
253                    Add one: `@wik-agent add anthropic <key>`"
254                .to_string();
255        }
256
257        let mut sorted: Vec<&ProviderEntry> = providers.iter().collect();
258        sorted.sort_by_key(|e| e.priority);
259
260        let threshold = *self.error_threshold.lock().unwrap();
261        let consecutive = *self.consecutive_errors.lock().unwrap();
262
263        let rows: Vec<String> = sorted.iter().map(|e| {
264            let active_icon = if e.active { "▶" } else { "  " };
265            let rotate_flag = if e.rotate_flag { " 🔄" } else { "" };
266            let key_hint = if e.api_key.len() > 8 {
267                format!("{}…{}", &e.api_key[..4], &e.api_key[e.api_key.len()-4..])
268            } else {
269                "••••••••".to_string()
270            };
271            format!(
272                "{active_icon} **{}** [P{}] `{}` · key: `{key_hint}` · calls: {} · errors: {}{rotate_flag}",
273                e.name, e.priority, e.model, e.call_count, e.error_count,
274            )
275        }).collect();
276
277        let active_name = sorted
278            .iter()
279            .find(|e| e.active)
280            .map(|e| e.name.as_str())
281            .unwrap_or("none");
282        format!(
283            "**🔑 WIK — Key Status**\n\n\
284             Active: **{active_name}** · threshold: {threshold} errors · consecutive now: {consecutive}\n\n{}",
285            rows.join("\n")
286        )
287    }
288
289    fn cmd_add(&self, parts: &[&str]) -> String {
290        if parts.len() < 2 {
291            return "Usage: `add <anthropic|gemini|openai|ollama> <api_key> [model]`".to_string();
292        }
293
294        let name = parts[0].to_lowercase();
295        if !["anthropic", "gemini", "openai", "ollama"].contains(&name.as_str()) {
296            return format!("Unknown provider `{name}`. Use: anthropic, gemini, openai, ollama.");
297        }
298
299        // OpenAI: require --confirm or show warning first
300        let has_confirm = parts.contains(&"--confirm");
301        if name == "openai" && !has_confirm {
302            return OPENAI_WARNING.to_string();
303        }
304
305        let api_key = parts[1].to_string();
306        let model = parts
307            .get(2)
308            .filter(|&&s| s != "--confirm")
309            .map(|&s| s.to_string())
310            .unwrap_or_else(|| ProviderEntry::default_model(&name).to_string());
311
312        let mut providers = self.providers.lock().unwrap();
313
314        if let Some(existing) = providers.iter_mut().find(|e| e.name == name) {
315            let old_key_hint = if existing.api_key.len() > 4 {
316                format!("{}…", &existing.api_key[..4])
317            } else {
318                "••••".to_string()
319            };
320            existing.api_key = api_key;
321            existing.model = model.clone();
322            return format!("🔑 Updated **{name}** (was `{old_key_hint}`) → model `{model}`");
323        }
324
325        // Assign next priority
326        let next_priority = providers.iter().map(|e| e.priority).max().unwrap_or(0) + 1;
327        let is_first = providers.is_empty();
328        providers.push(ProviderEntry {
329            name: name.clone(),
330            api_key,
331            model: model.clone(),
332            base_url: None,
333            priority: next_priority,
334            call_count: 0,
335            error_count: 0,
336            rotate_flag: false,
337            active: is_first,
338        });
339
340        let active_note = if is_first {
341            " — set as **active** (first provider)"
342        } else {
343            ""
344        };
345        format!("✅ Registered **{name}** · `{model}` [P{next_priority}]{active_note}")
346    }
347
348    fn cmd_priority(&self, parts: &[&str]) -> String {
349        if parts.is_empty() {
350            return "Usage: `priority <provider1> <provider2> …`\n\nExample: `priority anthropic gemini openai`".to_string();
351        }
352
353        let mut providers = self.providers.lock().unwrap();
354        let mut updated = Vec::new();
355
356        for (i, &name) in parts.iter().enumerate() {
357            let priority = i + 1;
358            if let Some(e) = providers.iter_mut().find(|e| e.name == name) {
359                e.priority = priority;
360                updated.push(format!("  {}. {name}", priority));
361            } else {
362                return format!(
363                    "❓ Provider `{name}` not found. Add it first with `add {name} <key>`."
364                );
365            }
366        }
367
368        format!("✅ Priority order updated:\n\n{}", updated.join("\n"))
369    }
370
371    fn cmd_switch(&self, parts: &[&str]) -> String {
372        if parts.is_empty() {
373            return "Usage: `switch <provider> [reason]`\n\nExample: `switch gemini manual override`".to_string();
374        }
375        let name = parts[0].to_lowercase();
376        let reason = if parts.len() > 1 {
377            parts[1..].join(" ")
378        } else {
379            "manual switch".to_string()
380        };
381
382        let mut providers = self.providers.lock().unwrap();
383        let found = providers.iter().any(|e| e.name == name);
384        if !found {
385            return format!("❓ Provider `{name}` not registered. Use `add {name} <key>` first.");
386        }
387
388        let target = providers.iter().find(|e| e.name == name).cloned().unwrap();
389        let prev = providers
390            .iter()
391            .find(|e| e.active)
392            .map(|e| e.name.clone())
393            .unwrap_or_else(|| "none".to_string());
394        for e in providers.iter_mut() {
395            e.active = e.name == name;
396        }
397
398        *self.consecutive_errors.lock().unwrap() = 0;
399        drop(providers);
400
401        self.publish_switch(&target, &reason);
402        format!(
403            "⚡ Switched: **{prev}** → **{name}** (`{}`)\n\n_Reason: {reason}_",
404            target.model
405        )
406    }
407
408    fn cmd_usage(&self) -> String {
409        let providers = self.providers.lock().unwrap();
410        if providers.is_empty() {
411            return "📭 No providers registered yet.".to_string();
412        }
413
414        let total_calls: u64 = providers.iter().map(|e| e.call_count).sum();
415        let total_errors: u64 = providers.iter().map(|e| e.error_count).sum();
416
417        let mut sorted: Vec<&ProviderEntry> = providers.iter().collect();
418        sorted.sort_by_key(|e| e.priority);
419
420        let rows: Vec<String> = sorted
421            .iter()
422            .map(|e| {
423                let bar = if total_calls > 0 {
424                    let frac = e.call_count as f64 / total_calls as f64;
425                    let filled = (frac * 10.0).round() as usize;
426                    format!("[{}{}]", "█".repeat(filled), "░".repeat(10 - filled))
427                } else {
428                    "[░░░░░░░░░░]".to_string()
429                };
430                let err_rate = if e.call_count > 0 {
431                    format!(
432                        "{:.1}% err",
433                        e.error_count as f64 / e.call_count as f64 * 100.0
434                    )
435                } else {
436                    "no calls".to_string()
437                };
438                format!("  **{}** {bar} {} calls · {err_rate}", e.name, e.call_count)
439            })
440            .collect();
441
442        format!(
443            "**📊 WIK Usage**\n\n{}\n\n**Total**: {} calls · {} errors\n\n\
444             _Use `@wif-agent add misc \"LLM API\" <cost>` to log spend._",
445            rows.join("\n"),
446            total_calls,
447            total_errors,
448        )
449    }
450
451    fn cmd_rotate(&self, parts: &[&str]) -> String {
452        if parts.is_empty() {
453            return "Usage: `rotate <provider>`\n\nFlags the provider key for rotation reminder."
454                .to_string();
455        }
456        let name = parts[0].to_lowercase();
457        let mut providers = self.providers.lock().unwrap();
458        match providers.iter_mut().find(|e| e.name == name) {
459            None => format!("❓ Provider `{name}` not found."),
460            Some(e) => {
461                e.rotate_flag = true;
462                format!(
463                    "🔄 **{name}** flagged for key rotation.\n\nWhen ready:\n`@wik-agent add {name} <new_key>`"
464                )
465            }
466        }
467    }
468
469    fn cmd_set(&self, parts: &[&str]) -> String {
470        if parts.len() < 2 {
471            return "Usage: `set threshold <n>`\n\nExample: `set threshold 5`".to_string();
472        }
473        match (parts[0], parts[1].parse::<u32>()) {
474            ("threshold", Ok(n)) if n > 0 => {
475                *self.error_threshold.lock().unwrap() = n;
476                format!("✅ Failover threshold set to **{n}** consecutive errors.")
477            }
478            ("threshold", _) => "Threshold must be a positive integer.".to_string(),
479            (k, _) => format!("Unknown setting `{k}`. Currently only `threshold` is supported."),
480        }
481    }
482
483    fn cmd_test(&self, parts: &[&str]) -> String {
484        let name = parts.first().copied();
485        let providers = self.providers.lock().unwrap();
486
487        let targets: Vec<&ProviderEntry> = match name {
488            Some(n) => providers.iter().filter(|e| e.name == n).collect(),
489            None => providers.iter().filter(|e| e.active).collect(),
490        };
491
492        if targets.is_empty() {
493            return "❓ No matching provider found.".to_string();
494        }
495
496        // We can't do async HTTP here (sync context), so just report config sanity.
497        let results: Vec<String> = targets
498            .iter()
499            .map(|e| {
500                let key_ok = !e.api_key.is_empty();
501                let icon = if key_ok { "✅" } else { "❌" };
502                let note = if key_ok {
503                    format!("key present · model `{}` · P{}", e.model, e.priority)
504                } else {
505                    "no API key set".to_string()
506                };
507                format!("  {icon} **{}** — {note}", e.name)
508            })
509            .collect();
510
511        format!(
512            "**🔍 WIK Provider Check**\n\n{}\n\n\
513             _Live ping not available from agent context — \
514             errors will surface via `system/llm/error` on first call._",
515            results.join("\n")
516        )
517    }
518
519    fn dispatch(&self, text: &str) -> Option<String> {
520        let arg = text.strip_prefix("@wik-agent").unwrap_or(text).trim();
521
522        // Transparently handle system/llm/error JSON payloads routed by main.rs
523        if let Ok(val) = serde_json::from_str::<serde_json::Value>(arg)
524            && (val.get("consecutiveErrors").is_some() || val.get("provider").is_some())
525        {
526            self.handle_llm_error(&val);
527            return None; // no user-visible reply for internal events
528        }
529
530        let parts: Vec<&str> = arg.split_whitespace().collect();
531        let cmd = parts.first().copied().unwrap_or("help");
532
533        Some(match cmd {
534            "status" => self.cmd_status(),
535            "add" => self.cmd_add(&parts[1..]),
536            "priority" => self.cmd_priority(&parts[1..]),
537            "switch" => self.cmd_switch(&parts[1..]),
538            "usage" => self.cmd_usage(),
539            "rotate" => self.cmd_rotate(&parts[1..]),
540            "set" => self.cmd_set(&parts[1..]),
541            "test" => self.cmd_test(&parts[1..]),
542            "help" | "" => "**WIK — Key Manager** 🔑\n\
543                 _Waldiez Intelligence Keys · NATO: kilo_\n\n\
544                 ```\n\
545                 add <anthropic|gemini|openai|ollama> <key> [model]\n\
546                 status                    all providers + active\n\
547                 priority <p1> <p2> …      set failover order\n\
548                 switch <provider> [why]   manual activate\n\
549                 test [provider]           config sanity check\n\
550                 usage                     call + error counts\n\
551                 rotate <provider>         flag key for rotation\n\
552                 set threshold <n>         errors before failover\n\
553                 help                      this message\n\
554                 ```\n\n\
555                 Auto-failover: after `threshold` consecutive errors WIK\n\
556                 switches to the next-priority provider automatically."
557                .to_string(),
558            _ => format!("Unknown command: `{cmd}`. Type `help` for the full command list."),
559        })
560    }
561}
562
563// ── Actor implementation ────────────────────────────────────────────────────────
564
565#[async_trait]
566impl Actor for WikAgent {
    /// Owned copy of the configured actor id.
    fn id(&self) -> String {
        self.config.id.clone()
    }
    /// Configured actor name.
    fn name(&self) -> &str {
        &self.config.name
    }
    /// Copy of the current lifecycle state.
    fn state(&self) -> ActorState {
        self.state.clone()
    }
    /// Another handle to the shared metrics (reference count bump only).
    fn metrics(&self) -> Arc<ActorMetrics> {
        Arc::clone(&self.metrics)
    }
    /// A sender through which messages can be delivered to this actor.
    fn mailbox(&self) -> mpsc::Sender<Message> {
        self.mailbox_tx.clone()
    }
    /// Value of the `protected` flag from the actor configuration.
    fn is_protected(&self) -> bool {
        self.config.protected
    }
585
586    async fn on_start(&mut self) -> Result<()> {
587        self.state = ActorState::Running;
588        if let Some(pub_) = &self.publisher {
589            pub_.publish(
590                wactorz_mqtt::topics::spawn(&self.config.id),
591                &serde_json::json!({
592                    "agentId":     self.config.id,
593                    "agentName":   self.config.name,
594                    "agentType":   "keymaster",
595                    "timestampMs": Self::now_ms(),
596                }),
597            );
598        }
599        Ok(())
600    }
601
602    async fn handle_message(&mut self, message: Message) -> Result<()> {
603        use wactorz_core::message::MessageType;
604
605        let content = match &message.payload {
606            MessageType::Text { content } => content.trim().to_string(),
607            MessageType::Task { description, .. } => description.trim().to_string(),
608            _ => return Ok(()),
609        };
610
611        if let Some(reply) = self.dispatch(&content) {
612            self.reply(&reply);
613        }
614        Ok(())
615    }
616
617    async fn on_heartbeat(&mut self) -> Result<()> {
618        if let Some(pub_) = &self.publisher {
619            pub_.publish(
620                wactorz_mqtt::topics::heartbeat(&self.config.id),
621                &serde_json::json!({
622                    "agentId":     self.config.id,
623                    "agentName":   self.config.name,
624                    "state":       self.state,
625                    "timestampMs": Self::now_ms(),
626                }),
627            );
628        }
629        Ok(())
630    }
631
632    async fn run(&mut self) -> Result<()> {
633        self.on_start().await?;
634
635        let mut rx = self
636            .mailbox_rx
637            .take()
638            .ok_or_else(|| anyhow::anyhow!("WikAgent already running"))?;
639
640        let mut hb = tokio::time::interval(std::time::Duration::from_secs(
641            self.config.heartbeat_interval_secs,
642        ));
643        hb.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
644
645        loop {
646            tokio::select! {
647                biased;
648                msg = rx.recv() => match msg {
649                    None    => break,
650                    Some(m) => {
651                        self.metrics.record_received();
652                        if let wactorz_core::message::MessageType::Command {
653                            command: wactorz_core::message::ActorCommand::Stop,
654                        } = &m.payload { break; }
655                        match self.handle_message(m).await {
656                            Ok(_)  => self.metrics.record_processed(),
657                            Err(e) => {
658                                tracing::error!("[{}] {e}", self.config.name);
659                                self.metrics.record_failed();
660                            }
661                        }
662                    }
663                },
664                _ = hb.tick() => {
665                    self.metrics.record_heartbeat();
666                    if let Err(e) = self.on_heartbeat().await {
667                        tracing::error!("[{}] heartbeat: {e}", self.config.name);
668                    }
669                }
670            }
671        }
672
673        self.state = ActorState::Stopped;
674        self.on_stop().await
675    }
676}