1use std::collections::HashMap;
8use std::sync::{Arc, Mutex};
9use std::time::{Duration, Instant};
10
11use serde::{Deserialize, Serialize};
12use tokio::sync::{RwLock, mpsc};
13
14use crate::actor::{Actor, ActorState};
15use crate::message::{ActorCommand, Message};
16use crate::metrics::ActorMetrics;
17use crate::publish::EventPublisher;
18
19#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
27#[serde(rename_all = "snake_case")]
28pub enum SupervisorStrategy {
29 #[default]
30 OneForOne,
31 OneForAll,
32 RestForOne,
33}
34
35pub type ActorFactory = Arc<dyn Fn() -> Box<dyn Actor> + Send + Sync + 'static>;
37
38struct SpecEntry {
39 factory: ActorFactory,
40 strategy: SupervisorStrategy,
41 max_restarts: u32,
42 restart_window: Duration,
43 restart_delay: Duration,
44 actor_id: Option<String>,
46 restart_times: Vec<Instant>,
48 stopped: bool,
50}
51
52impl SpecEntry {
53 fn record_restart(&mut self) -> bool {
55 let now = Instant::now();
56 self.restart_times
57 .retain(|t| now.duration_since(*t) < self.restart_window);
58 self.restart_times.push(now);
59 (self.restart_times.len() as u32) <= self.max_restarts
60 }
61
62 fn exhausted(&self) -> bool {
63 let now = Instant::now();
64 let recent = self
65 .restart_times
66 .iter()
67 .filter(|t| now.duration_since(**t) < self.restart_window)
68 .count();
69 (recent as u32) >= self.max_restarts
70 }
71}
72
73#[derive(Debug, Clone)]
77pub struct ActorEntry {
78 pub id: String,
79 pub name: String,
80 pub state: ActorState,
81 pub mailbox: mpsc::Sender<Message>,
82 pub protected: bool,
84 pub metrics: Arc<ActorMetrics>,
86 pub supervisor_id: Option<String>,
88}
89
90#[derive(Debug, Default, Clone)]
92pub struct ActorRegistry {
93 actors: Arc<RwLock<HashMap<String, ActorEntry>>>,
94}
95
96impl ActorRegistry {
97 pub fn new() -> Self {
98 Self::default()
99 }
100
101 pub async fn register(&self, entry: ActorEntry) {
103 let mut map = self.actors.write().await;
104 map.insert(entry.id.clone(), entry);
105 }
106
107 pub async fn deregister(&self, id: &str) {
109 let mut map = self.actors.write().await;
110 map.remove(id);
111 }
112
113 pub async fn get(&self, id: &str) -> Option<ActorEntry> {
115 let map = self.actors.read().await;
116 map.get(id).cloned()
117 }
118
119 pub async fn get_by_name(&self, name: &str) -> Option<ActorEntry> {
121 let map = self.actors.read().await;
122 map.values().find(|e| e.name == name).cloned()
123 }
124
125 pub async fn list(&self) -> Vec<ActorEntry> {
127 let map = self.actors.read().await;
128 map.values().cloned().collect()
129 }
130
131 pub async fn update_state(&self, id: &str, state: ActorState) {
133 let mut map = self.actors.write().await;
134 if let Some(entry) = map.get_mut(id) {
135 entry.state = state;
136 }
137 }
138
139 pub async fn send(&self, id: &str, message: Message) -> anyhow::Result<()> {
141 let map = self.actors.read().await;
142 let entry = map
143 .get(id)
144 .ok_or_else(|| anyhow::anyhow!("actor {id} not found"))?;
145 entry
146 .mailbox
147 .send(message)
148 .await
149 .map_err(|e| anyhow::anyhow!("mailbox full or closed: {e}"))
150 }
151
152 pub async fn broadcast(&self, message: Message) {
154 let map = self.actors.read().await;
155 for entry in map.values() {
156 let _ = entry.mailbox.send(message.clone()).await;
157 }
158 }
159}
160
161#[derive(Debug, Clone)]
165pub struct ActorSystem {
166 pub registry: ActorRegistry,
167 publisher: EventPublisher,
168}
169
170impl ActorSystem {
171 pub fn new() -> Self {
172 let (publisher, _rx) = EventPublisher::channel();
173 Self {
174 registry: ActorRegistry::new(),
175 publisher,
176 }
177 }
178
179 pub fn with_publisher(publisher: EventPublisher) -> Self {
180 Self {
181 registry: ActorRegistry::new(),
182 publisher,
183 }
184 }
185
186 pub fn publisher(&self) -> EventPublisher {
187 self.publisher.clone()
188 }
189
190 fn _inject_fn(&self) -> impl Fn(ActorEntry) -> ActorEntry + '_ {
191 |e| e }
193
194 pub async fn spawn_actor(&self, actor: Box<dyn Actor>) -> anyhow::Result<String> {
196 self.spawn_actor_supervised(actor, None).await
197 }
198
199 pub async fn spawn_actor_supervised(
201 &self,
202 actor: Box<dyn Actor>,
203 supervisor_id: Option<String>,
204 ) -> anyhow::Result<String> {
205 let id = actor.id();
206 let name = actor.name().to_string();
207 let mailbox = actor.mailbox();
208 let protected = actor.is_protected();
209 let metrics = actor.metrics();
210
211 let entry = ActorEntry {
212 id: id.clone(),
213 name: name.clone(),
214 state: ActorState::Initializing,
215 mailbox,
216 protected,
217 metrics,
218 supervisor_id,
219 };
220 self.registry.register(entry).await;
221
222 let registry = self.registry.clone();
223 let id_task = id.clone();
224 tokio::spawn(async move {
225 let mut actor = actor;
226 registry.update_state(&id_task, ActorState::Running).await;
227 match actor.run().await {
228 Ok(_) => registry.update_state(&id_task, ActorState::Stopped).await,
229 Err(e) => {
230 tracing::error!("[{}] run error: {e}", id_task);
231 registry
232 .update_state(&id_task, ActorState::Failed(e.to_string()))
233 .await;
234 }
235 }
236 registry.deregister(&id_task).await;
237 tracing::info!("Actor {name} ({id_task}) stopped");
238 });
239 Ok(id)
240 }
241
242 pub async fn stop_actor(&self, name: &str) -> anyhow::Result<()> {
244 let entry = self
245 .registry
246 .get_by_name(name)
247 .await
248 .ok_or_else(|| anyhow::anyhow!("actor '{name}' not found"))?;
249 if entry.protected {
250 anyhow::bail!("actor '{name}' is protected");
251 }
252 self.registry
253 .send(
254 &entry.id,
255 Message::command(entry.id.clone(), ActorCommand::Stop),
256 )
257 .await
258 }
259
260 pub async fn shutdown(&self) -> anyhow::Result<()> {
262 let actors = self.registry.list().await;
263 for entry in actors {
264 if !entry.protected {
265 let _ = self
266 .registry
267 .send(
268 &entry.id,
269 Message::command(entry.id.clone(), ActorCommand::Stop),
270 )
271 .await;
272 }
273 }
274 Ok(())
275 }
276}
277
278impl Default for ActorSystem {
279 fn default() -> Self {
280 Self::new()
281 }
282}
283
284pub struct Supervisor {
302 system: ActorSystem,
303 specs: Arc<Mutex<Vec<(String, SpecEntry)>>>,
304 poll_interval: Duration,
305 watch_task: Option<tokio::task::JoinHandle<()>>,
306}
307
308impl Supervisor {
309 pub fn new(system: ActorSystem) -> Self {
310 Self {
311 system,
312 specs: Arc::new(Mutex::new(Vec::new())),
313 poll_interval: Duration::from_secs(2),
314 watch_task: None,
315 }
316 }
317
318 pub fn with_poll_interval(system: ActorSystem, poll_interval: Duration) -> Self {
319 Self {
320 system,
321 specs: Arc::new(Mutex::new(Vec::new())),
322 poll_interval,
323 watch_task: None,
324 }
325 }
326
327 pub fn supervise(
329 &mut self,
330 name: impl Into<String>,
331 factory: ActorFactory,
332 strategy: SupervisorStrategy,
333 max_restarts: u32,
334 restart_window_secs: f64,
335 restart_delay_secs: f64,
336 ) -> &mut Self {
337 let entry = SpecEntry {
338 factory,
339 strategy,
340 max_restarts,
341 restart_window: Duration::from_secs_f64(restart_window_secs),
342 restart_delay: Duration::from_secs_f64(restart_delay_secs),
343 actor_id: None,
344 restart_times: Vec::new(),
345 stopped: false,
346 };
347 self.specs.lock().unwrap().push((name.into(), entry));
348 self
349 }
350
351 pub async fn start(&mut self) -> anyhow::Result<()> {
353 let sup_id = format!("supervisor-{}", uuid::Uuid::new_v4());
354
355 let tasks: Vec<(String, ActorFactory)> = {
358 let specs = self.specs.lock().unwrap();
359 specs
360 .iter()
361 .map(|(name, e)| (name.clone(), Arc::clone(&e.factory)))
362 .collect()
363 };
364
365 for (name, factory) in &tasks {
366 let actor = factory();
367 let actor_id = self
368 .system
369 .spawn_actor_supervised(actor, Some(sup_id.clone()))
370 .await?;
371 {
372 let mut specs = self.specs.lock().unwrap();
373 if let Some((_, entry)) = specs.iter_mut().find(|(n, _)| n == name) {
374 entry.actor_id = Some(actor_id);
375 }
376 }
377 tracing::info!("[Supervisor] Spawned '{name}'");
378 }
379
380 let specs_clone = Arc::clone(&self.specs);
382 let system_clone = self.system.clone();
383 let poll = self.poll_interval;
384 let sup_id_clone = sup_id.clone();
385
386 self.watch_task = Some(tokio::spawn(async move {
387 loop {
388 tokio::time::sleep(poll).await;
389 watch_once(&system_clone, &specs_clone, &sup_id_clone).await;
390 }
391 }));
392
393 tracing::info!("[Supervisor] Started — supervising {} actors", {
394 self.specs.lock().unwrap().len()
395 });
396 Ok(())
397 }
398
399 pub async fn stop(&mut self) {
401 if let Some(task) = self.watch_task.take() {
402 task.abort();
403 }
404 let actor_ids: Vec<(String, Option<String>)> = {
407 let mut specs = self.specs.lock().unwrap();
408 specs
409 .iter_mut()
410 .map(|(name, entry)| {
411 entry.stopped = true;
412 (name.clone(), entry.actor_id.clone())
413 })
414 .collect()
415 };
416 for (name, actor_id) in actor_ids {
417 if let Some(id) = actor_id {
418 let _ = self
419 .system
420 .registry
421 .send(&id, Message::command(id.clone(), ActorCommand::Stop))
422 .await;
423 }
424 tracing::debug!("[Supervisor] Requested stop for '{name}'");
425 }
426 }
427
428 pub fn status(&self) -> Vec<serde_json::Value> {
430 let specs = self.specs.lock().unwrap();
431 specs
432 .iter()
433 .map(|(name, e)| {
434 let now = Instant::now();
435 let recent = e
436 .restart_times
437 .iter()
438 .filter(|t| now.duration_since(**t) < e.restart_window)
439 .count();
440 serde_json::json!({
441 "name": name,
442 "strategy": format!("{:?}", e.strategy),
443 "max_restarts": e.max_restarts,
444 "restarts_used": recent,
445 "exhausted": e.exhausted(),
446 "actor_id": e.actor_id,
447 })
448 })
449 .collect()
450 }
451}
452
453async fn watch_once(system: &ActorSystem, specs: &Mutex<Vec<(String, SpecEntry)>>, sup_id: &str) {
456 let failed: Vec<String> = {
458 let specs_guard = specs.lock().unwrap();
459 let mut out = Vec::new();
460 for (name, entry) in specs_guard.iter() {
461 if entry.stopped {
462 continue;
463 }
464 let is_dead = match &entry.actor_id {
465 None => true,
466 Some(_id) => {
467 false }
472 };
473 let _ = is_dead; out.push(name.clone()); }
476 out
477 };
478
479 let mut truly_failed: Vec<String> = Vec::new();
481 for name in &failed {
482 let actor_id_opt = {
483 let specs_guard = specs.lock().unwrap();
484 specs_guard
485 .iter()
486 .find(|(n, _)| n == name)
487 .and_then(|(_, e)| e.actor_id.clone())
488 };
489 let dead = match actor_id_opt {
490 None => true,
491 Some(ref id) => match system.registry.get(id).await {
492 None => true, Some(e) => matches!(e.state, ActorState::Failed(_)),
494 },
495 };
496 let stopped = specs
498 .lock()
499 .unwrap()
500 .iter()
501 .find(|(n, _)| n == name)
502 .map(|(_, e)| e.stopped)
503 .unwrap_or(true);
504 if dead && !stopped {
505 truly_failed.push(name.clone());
506 }
507 }
508
509 if truly_failed.is_empty() {
510 return;
511 }
512
513 for crashed_name in &truly_failed {
514 let strategy = {
515 let specs_guard = specs.lock().unwrap();
516 specs_guard
517 .iter()
518 .find(|(n, _)| n == crashed_name)
519 .map(|(_, e)| e.strategy.clone())
520 .unwrap_or(SupervisorStrategy::OneForOne)
521 };
522
523 tracing::warn!(
524 "[Supervisor] '{crashed_name}' failed — applying {:?} strategy.",
525 strategy
526 );
527
528 match strategy {
529 SupervisorStrategy::OneForOne => {
530 restart_one(system, specs, crashed_name, sup_id).await;
531 }
532 SupervisorStrategy::OneForAll => {
533 let all_names: Vec<String> = specs
535 .lock()
536 .unwrap()
537 .iter()
538 .map(|(n, _)| n.clone())
539 .collect();
540 for name in all_names.iter().rev() {
541 if name != crashed_name {
542 stop_one(system, specs, name).await;
543 }
544 }
545 for name in &all_names {
546 restart_one(system, specs, name, sup_id).await;
547 }
548 }
549 SupervisorStrategy::RestForOne => {
550 let all_names: Vec<String> = specs
551 .lock()
552 .unwrap()
553 .iter()
554 .map(|(n, _)| n.clone())
555 .collect();
556 let idx = all_names
557 .iter()
558 .position(|n| n == crashed_name)
559 .unwrap_or(0);
560 let affected: Vec<String> = all_names[idx..].to_vec();
561 for name in affected.iter().rev() {
562 if name != crashed_name {
563 stop_one(system, specs, name).await;
564 }
565 }
566 for name in &affected {
567 restart_one(system, specs, name, sup_id).await;
568 }
569 }
570 }
571 }
572}
573
574async fn stop_one(system: &ActorSystem, specs: &Mutex<Vec<(String, SpecEntry)>>, name: &str) {
575 let actor_id = specs
576 .lock()
577 .unwrap()
578 .iter()
579 .find(|(n, _)| n == name)
580 .and_then(|(_, e)| e.actor_id.clone());
581
582 if let Some(id) = actor_id {
583 if system.registry.get(&id).await.is_some() {
588 let _ = system
589 .registry
590 .send(&id, Message::command(id.clone(), ActorCommand::Stop))
591 .await;
592 tokio::time::sleep(Duration::from_millis(200)).await;
594 }
595 }
596 let mut specs_guard = specs.lock().unwrap();
598 if let Some((_, entry)) = specs_guard.iter_mut().find(|(n, _)| n == name) {
599 entry.actor_id = None;
600 }
601}
602
603async fn restart_one(
604 system: &ActorSystem,
605 specs: &Mutex<Vec<(String, SpecEntry)>>,
606 name: &str,
607 sup_id: &str,
608) {
609 let (_exhausted, delay, within_budget, factory) = {
610 let mut specs_guard = specs.lock().unwrap();
611 let Some((_, entry)) = specs_guard.iter_mut().find(|(n, _)| n == name) else {
612 return;
613 };
614 if entry.exhausted() {
615 tracing::error!(
616 "[Supervisor] '{name}' exhausted restart budget ({} restarts). Giving up.",
617 entry.max_restarts
618 );
619 return;
620 }
621 let budget_ok = entry.record_restart();
622 (
623 false,
624 entry.restart_delay,
625 budget_ok,
626 Arc::clone(&entry.factory),
627 )
628 };
629
630 if !within_budget {
631 return;
632 }
633
634 stop_one(system, specs, name).await;
636
637 if delay > Duration::ZERO {
638 tokio::time::sleep(delay).await;
639 }
640
641 let restart_count = {
642 let specs_guard = specs.lock().unwrap();
643 specs_guard
644 .iter()
645 .find(|(n, _)| n == name)
646 .map(|(_, e)| e.restart_times.len() as u64)
647 .unwrap_or(0)
648 };
649
650 let actor = factory();
651 match system
652 .spawn_actor_supervised(actor, Some(sup_id.to_string()))
653 .await
654 {
655 Ok(new_id) => {
656 if let Some(entry) = system.registry.get(&new_id).await {
658 entry
659 .metrics
660 .restart_count
661 .store(restart_count, std::sync::atomic::Ordering::Relaxed);
662 }
663 let mut specs_guard = specs.lock().unwrap();
664 if let Some((_, e)) = specs_guard.iter_mut().find(|(n, _)| n == name) {
665 e.actor_id = Some(new_id);
666 }
667 tracing::info!("[Supervisor] '{name}' restarted (#{restart_count}).");
668 }
669 Err(e) => {
670 tracing::error!("[Supervisor] Failed to restart '{name}': {e}");
671 }
672 }
673}