use std::{
any::Any,
fmt::{Display, Formatter},
pin::Pin,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
task::{ready, Context, Poll},
};
use futures::{
future::{select, BoxFuture, FusedFuture, Shared},
pin_mut, Future, FutureExt, TryFutureExt,
};
use reth_tasks::{shutdown::GracefulShutdown, TaskSpawner, TaskSpawnerExt};
use tokio::{
runtime::Handle,
sync::{
mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
oneshot, OnceCell,
},
task::JoinHandle,
};
use tracing::{debug, error, Instrument};
/// Process-wide executor, set by the first `BrontesTaskManager::new` call (later calls leave
/// it untouched) and returned by `BrontesTaskExecutor::current`.
static EXECUTOR: OnceCell<BrontesTaskExecutor> = OnceCell::const_new();
/// Owns the runtime handle, the shutdown [`Signal`], and the receiving end of the channel
/// through which critical-task failures are reported.
///
/// It implements [`Future`]; polling it resolves with the next [`PanickedTaskError`] that
/// arrives on that channel.
#[derive(Debug)]
#[must_use = "BrontesTaskManager must be polled to monitor critical tasks"]
pub struct BrontesTaskManager {
    /// Runtime that executor tasks are spawned onto.
    handle: Handle,
    /// Sender cloned into every executor. Holding it here also keeps the channel open, so the
    /// receiver never observes closure.
    panicked_tasks_tx: UnboundedSender<PanickedTaskError>,
    /// Receiver polled by the `Future` impl.
    panicked_tasks_rx: UnboundedReceiver<PanickedTaskError>,
    /// Shutdown trigger. Graceful shutdown drops it, which completes every paired
    /// [`Shutdown`] future.
    signal: Option<Signal>,
    /// Shutdown future cloned into executors and spawned tasks.
    on_shutdown: Shutdown,
    /// Counter bumped for every graceful-shutdown guard handed out; graceful shutdown waits
    /// for it to drop back to zero.
    graceful_tasks: Arc<AtomicUsize>,
}
impl BrontesTaskManager {
    /// Creates a manager for the Tokio runtime of the current context.
    ///
    /// Equivalent to `Self::new(Handle::current(), false)`, so the panic hook may be
    /// overridden as described on [`BrontesTaskManager::new`].
    ///
    /// # Panics
    ///
    /// Panics if called outside of a Tokio runtime context.
    pub fn current() -> Self {
        let handle = Handle::current();
        Self::new(handle, false)
    }

    /// Creates a manager whose executors spawn onto `handle`.
    ///
    /// The panic hook is overridden when both of these hold:
    /// - `RUST_BACKTRACE` is unset or `"0"`;
    /// - `no_panic_override` is `false`.
    ///
    /// In that case the process-wide hook is replaced with one that forwards every panic
    /// message to this manager as a [`PanickedTaskError`] named `"thread"`. The default
    /// hook no longer prints.
    ///
    /// The first manager created in the process also becomes the source of
    /// [`BrontesTaskExecutor::current`]; later managers do not replace it.
    pub fn new(handle: Handle, no_panic_override: bool) -> Self {
        let (panicked_tasks_tx, panicked_tasks_rx) = unbounded_channel();
        let (signal, on_shutdown) = signal();

        // An unset (or non-unicode) RUST_BACKTRACE is treated the same as "0".
        let bt_level = std::env::var("RUST_BACKTRACE").unwrap_or_else(|_| String::from("0"));
        if bt_level == "0" && !no_panic_override {
            let tx = panicked_tasks_tx.clone();
            std::panic::set_hook(Box::new(move |info| {
                let msg = match info.payload().downcast_ref::<&'static str>() {
                    Some(s) => *s,
                    None => match info.payload().downcast_ref::<String>() {
                        Some(s) => &s[..],
                        None => "Box<dyn Any>",
                    },
                };
                // `location()` currently always returns `Some`. Unwrapping would still turn a
                // future `None` into a panic inside the panic hook, which aborts the process.
                let error_msg = match info.location() {
                    Some(location) => format!("panic happened at {location}:\n {msg}"),
                    None => format!("panic happened at <unknown location>:\n {msg}"),
                };
                // This hook replaced the default one. If the manager's receiver is gone, the
                // message would vanish entirely, so fall back to stderr.
                if let Err(unsent) =
                    tx.send(PanickedTaskError::new("thread", Box::new(error_msg)))
                {
                    use std::io::Write as _;
                    let _ = writeln!(std::io::stderr(), "{}", unsent.0);
                }
            }));
        }

        let this = Self {
            handle,
            panicked_tasks_tx,
            panicked_tasks_rx,
            signal: Some(signal),
            on_shutdown,
            graceful_tasks: Arc::new(AtomicUsize::new(0)),
        };
        // Only the first manager registers; `set` fails (ignored) once the cell is filled.
        let _ = EXECUTOR.set(this.executor());
        this
    }

    /// Returns a new executor handle that shares this manager's runtime, shutdown signal,
    /// panic channel and graceful-task counter.
    pub fn executor(&self) -> BrontesTaskExecutor {
        BrontesTaskExecutor {
            handle: self.handle.clone(),
            on_shutdown: self.on_shutdown.clone(),
            panicked_tasks_tx: self.panicked_tasks_tx.clone(),
            graceful_tasks: Arc::clone(&self.graceful_tasks),
        }
    }

    /// Signals shutdown and blocks the calling thread until every graceful task finishes.
    pub fn graceful_shutdown(self) {
        let _ = self.do_graceful_shutdown(None);
    }

    /// Signals shutdown and blocks until every graceful task finishes or `timeout` elapses.
    ///
    /// Returns `true` if all graceful tasks finished in time, `false` on timeout.
    pub fn graceful_shutdown_with_timeout(self, timeout: std::time::Duration) -> bool {
        self.do_graceful_shutdown(Some(timeout))
    }

    /// Drops the shutdown signal, then busy-waits on the calling thread until the graceful
    /// task counter reaches zero or the optional deadline passes.
    fn do_graceful_shutdown(self, timeout: Option<std::time::Duration>) -> bool {
        // Dropping the sender (rather than firing it) also completes every `Shutdown`: the
        // oneshot receiver resolves with an error, which `Shutdown::poll` treats as ready.
        drop(self.signal);
        // A timeout too large to represent as an `Instant` means "no deadline" instead of
        // panicking on overflow.
        let deadline = timeout.and_then(|t| std::time::Instant::now().checked_add(t));
        while self.graceful_tasks.load(Ordering::Relaxed) > 0 {
            if deadline.is_some_and(|deadline| std::time::Instant::now() > deadline) {
                debug!("graceful shutdown timed out");
                return false
            }
            std::hint::spin_loop();
        }
        debug!("gracefully shut down");
        true
    }
}
impl Future for BrontesTaskManager {
    type Output = PanickedTaskError;

    /// Resolves with the next [`PanickedTaskError`] sent to this manager's channel.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        let next = ready!(this.panicked_tasks_rx.poll_recv(cx));
        // `this` owns a sender for the same channel, so a closed channel (`None`) cannot occur.
        let task_error = next.expect("stream can not end");
        Poll::Ready(task_error)
    }
}
/// Cloneable handle used to spawn tasks that are tied to a [`BrontesTaskManager`].
///
/// All clones share the manager's shutdown future, panic-report channel and graceful-task
/// counter.
#[derive(Debug, Clone)]
pub struct BrontesTaskExecutor {
    /// Runtime that tasks are spawned onto.
    handle: Handle,
    /// Completes when the manager's shutdown signal is fired or dropped.
    on_shutdown: Shutdown,
    /// Channel used to report critical-task panics and explicit shutdown requests.
    panicked_tasks_tx: UnboundedSender<PanickedTaskError>,
    /// Counter shared with the manager; incremented for every graceful-shutdown guard.
    graceful_tasks: Arc<AtomicUsize>,
}
impl BrontesTaskExecutor {
pub fn current() -> &'static Self {
EXECUTOR
.get()
.expect("not running in a brontes task manager scope")
}
pub fn trigger_shutdown(&self, task_name: &'static str) {
let _ = self
.panicked_tasks_tx
.send(PanickedTaskError { error: None, task_name });
}
pub fn handle(&self) -> &Handle {
&self.handle
}
pub fn on_shutdown_signal(&self) -> &Shutdown {
&self.on_shutdown
}
#[track_caller]
pub fn block_on<F: Future>(&self, future: F) -> F::Output {
self.handle.block_on(future)
}
fn spawn_on_rt<F>(&self, fut: F, task_kind: TaskKind) -> JoinHandle<()>
where
F: Future<Output = ()> + Send + 'static,
{
match task_kind {
TaskKind::Default => self.handle.spawn(fut),
TaskKind::Blocking => {
let handle = self.handle.clone();
self.handle.spawn_blocking(move || handle.block_on(fut))
}
}
}
fn spawn_task_as<F>(&self, fut: F, task_kind: TaskKind) -> JoinHandle<()>
where
F: Future<Output = ()> + Send + 'static,
{
let on_shutdown = self.on_shutdown.clone();
let task = {
async move {
pin_mut!(fut);
let _ = select(on_shutdown, fut).await;
}
}
.in_current_span();
self.spawn_on_rt(task, task_kind)
}
pub fn spawn<F>(&self, fut: F) -> JoinHandle<()>
where
F: Future<Output = ()> + Send + 'static,
{
self.spawn_task_as(fut, TaskKind::Default)
}
pub fn spawn_blocking<F>(&self, fut: F) -> JoinHandle<()>
where
F: Future<Output = ()> + Send + 'static,
{
self.spawn_task_as(fut, TaskKind::Blocking)
}
pub fn spawn_with_signal<F>(&self, f: impl FnOnce(Shutdown) -> F) -> JoinHandle<()>
where
F: Future<Output = ()> + Send + 'static,
{
let on_shutdown = self.on_shutdown.clone();
let fut = f(on_shutdown);
let task = fut.in_current_span();
self.handle.spawn(task)
}
fn spawn_critical_as<F>(
&self,
name: &'static str,
fut: F,
task_kind: TaskKind,
) -> JoinHandle<()>
where
F: Future<Output = ()> + Send + 'static,
{
let panicked_tasks_tx = self.panicked_tasks_tx.clone();
let on_shutdown = self.on_shutdown.clone();
let task = std::panic::AssertUnwindSafe(fut)
.catch_unwind()
.map_err(move |error| {
let task_error = PanickedTaskError::new(name, error);
error!("{task_error}");
let _ = panicked_tasks_tx.send(task_error);
})
.in_current_span();
let task = async move {
pin_mut!(task);
let _ = select(on_shutdown, task).await;
};
self.spawn_on_rt(task, task_kind)
}
pub fn spawn_critical_blocking<F>(&self, name: &'static str, fut: F) -> JoinHandle<()>
where
F: Future<Output = ()> + Send + 'static,
{
self.spawn_critical_as(name, fut, TaskKind::Blocking)
}
pub fn spawn_critical<F>(&self, name: &'static str, fut: F) -> JoinHandle<()>
where
F: Future<Output = ()> + Send + 'static,
{
self.spawn_critical_as(name, fut, TaskKind::Default)
}
pub fn spawn_critical_with_shutdown_signal<F>(
&self,
name: &'static str,
f: impl FnOnce(Shutdown) -> F,
) -> JoinHandle<()>
where
F: Future<Output = ()> + Send + 'static,
{
let panicked_tasks_tx = self.panicked_tasks_tx.clone();
let on_shutdown = self.on_shutdown.clone();
let fut = f(on_shutdown);
let task = std::panic::AssertUnwindSafe(fut)
.catch_unwind()
.map_err(move |error| {
let task_error = PanickedTaskError::new(name, error);
error!("{task_error}");
let _ = panicked_tasks_tx.send(task_error);
})
.map(|_| ())
.in_current_span();
self.handle.spawn(task)
}
pub fn get_graceful_shutdown(&self) -> GracefulShutdown {
let on_shutdown = LocalGracefulShutdown::new(
self.on_shutdown.clone(),
LocalGracefulShutdownGuard::new(Arc::clone(&self.graceful_tasks)),
);
unsafe { std::mem::transmute(on_shutdown) }
}
pub fn spawn_critical_with_graceful_shutdown_signal<F>(
&self,
name: &'static str,
f: impl FnOnce(GracefulShutdown) -> F,
) -> JoinHandle<()>
where
F: Future<Output = ()> + Send + 'static,
{
let panicked_tasks_tx = self.panicked_tasks_tx.clone();
let on_shutdown = LocalGracefulShutdown::new(
self.on_shutdown.clone(),
LocalGracefulShutdownGuard::new(Arc::clone(&self.graceful_tasks)),
);
#[allow(clippy::missing_transmute_annotations)]
let fut = f(unsafe { std::mem::transmute(on_shutdown) });
let task = std::panic::AssertUnwindSafe(fut)
.catch_unwind()
.map_err(move |error| {
let task_error = PanickedTaskError::new(name, error);
error!("{task_error}");
let _ = panicked_tasks_tx.send(task_error);
})
.map(|_| ())
.in_current_span();
self.handle.spawn(task)
}
pub fn spawn_with_graceful_shutdown_signal<F>(
&self,
f: impl FnOnce(GracefulShutdown) -> F,
) -> JoinHandle<()>
where
F: Future<Output = ()> + Send + 'static,
{
let on_shutdown = LocalGracefulShutdown::new(
self.on_shutdown.clone(),
LocalGracefulShutdownGuard::new(Arc::clone(&self.graceful_tasks)),
);
#[allow(clippy::missing_transmute_annotations)]
let fut = f(unsafe { std::mem::transmute(on_shutdown) });
self.handle.spawn(fut)
}
}
/// Adapts the executor to reth's object-safe [`TaskSpawner`] interface.
///
/// Each method forwards to the inherent method of the same name. The fully qualified paths
/// make it explicit that these are not recursive trait calls.
impl TaskSpawner for BrontesTaskExecutor {
    fn spawn(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
        BrontesTaskExecutor::spawn(self, fut)
    }

    fn spawn_critical(&self, name: &'static str, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
        BrontesTaskExecutor::spawn_critical(self, name, fut)
    }

    fn spawn_blocking(&self, fut: BoxFuture<'static, ()>) -> JoinHandle<()> {
        BrontesTaskExecutor::spawn_blocking(self, fut)
    }

    fn spawn_critical_blocking(
        &self,
        name: &'static str,
        fut: BoxFuture<'static, ()>,
    ) -> JoinHandle<()> {
        BrontesTaskExecutor::spawn_critical_blocking(self, name, fut)
    }
}
/// Adapts the executor to reth's [`TaskSpawnerExt`] graceful-shutdown spawning API by
/// forwarding to the inherent methods.
impl TaskSpawnerExt for BrontesTaskExecutor {
    fn spawn_critical_with_graceful_shutdown_signal<Fut>(
        &self,
        name: &'static str,
        f: impl FnOnce(GracefulShutdown) -> Fut,
    ) -> JoinHandle<()>
    where
        Fut: Future<Output = ()> + Send + 'static,
    {
        let executor: &BrontesTaskExecutor = self;
        BrontesTaskExecutor::spawn_critical_with_graceful_shutdown_signal(executor, name, f)
    }

    fn spawn_with_graceful_shutdown_signal<Fut>(
        &self,
        f: impl FnOnce(GracefulShutdown) -> Fut,
    ) -> JoinHandle<()>
    where
        Fut: Future<Output = ()> + Send + 'static,
    {
        let executor: &BrontesTaskExecutor = self;
        BrontesTaskExecutor::spawn_with_graceful_shutdown_signal(executor, f)
    }
}
/// Cloneable future that completes once its paired [`Signal`] is fired or dropped.
#[derive(Debug, Clone)]
pub struct Shutdown(Shared<oneshot::Receiver<()>>);

impl Future for Shutdown {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let shared = &mut self.get_mut().0;
        // A `Shared` handle must not be polled again after it has yielded, so report
        // readiness directly for a handle that has already completed.
        if shared.is_terminated() {
            return Poll::Ready(())
        }
        // Both a fired signal (`Ok(())`) and a dropped sender (`Err(_)`) mean shutdown.
        shared.poll_unpin(cx).map(|_| ())
    }
}
/// Sending half of a shutdown notification.
///
/// Firing it, or simply dropping it, completes the paired [`Shutdown`] future and all of
/// its clones.
#[derive(Debug)]
pub struct Signal(oneshot::Sender<()>);

impl Signal {
    /// Fires the signal. The case where the receiver is already gone is ignored.
    pub fn fire(self) {
        let Self(sender) = self;
        let _ = sender.send(());
    }
}

/// Creates a connected [`Signal`] / [`Shutdown`] pair.
pub fn signal() -> (Signal, Shutdown) {
    let (tx, rx) = oneshot::channel();
    let shutdown = Shutdown(rx.shared());
    (Signal(tx), shutdown)
}
/// Report sent to [`BrontesTaskManager`]. It is produced in three cases:
/// - a critical task panics;
/// - the overriding panic hook fires;
/// - `BrontesTaskExecutor::trigger_shutdown` is called.
#[derive(Debug, thiserror::Error)]
pub struct PanickedTaskError {
    /// Name of the task, or `"thread"` for panics forwarded by the panic hook.
    task_name: &'static str,
    /// Panic message when the payload was a string. `None` for other payload types and
    /// for explicit shutdown requests.
    error: Option<String>,
}
impl Display for PanickedTaskError {
    /// Renders "Critical task `<name>` panicked", followed by ": `<message>`" when a
    /// message is present.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self.error.as_deref() {
            Some(error) => {
                write!(f, "Critical task `{}` panicked: `{error}`", self.task_name)
            }
            None => write!(f, "Critical task `{}` panicked", self.task_name),
        }
    }
}
impl PanickedTaskError {
    /// Builds an error for `task_name` from a panic payload.
    ///
    /// `String` and `&str` payloads are kept as the message; any other payload type
    /// yields no message.
    fn new(task_name: &'static str, error: Box<dyn Any>) -> Self {
        let error = error
            .downcast::<String>()
            .map(|message| *message)
            .or_else(|payload| payload.downcast::<&str>().map(|message| message.to_string()))
            .ok();
        Self { task_name, error }
    }
}
/// Selects where `BrontesTaskExecutor::spawn_on_rt` runs a task.
enum TaskKind {
    /// On the runtime's async worker threads, via `Handle::spawn`.
    Default,
    /// On the blocking thread pool, via `Handle::spawn_blocking`, with the future driven
    /// by `Handle::block_on`.
    Blocking,
}
/// Local stand-in for `reth_tasks::shutdown::GracefulShutdown`.
///
/// `BrontesTaskExecutor` builds it and then `transmute`s it into the reth type.
///
/// NOTE(review): that transmute is only sound if these fields match reth's
/// `GracefulShutdown` in type and layout. Reth's definition is not visible here, and
/// neither struct appears to be `#[repr(C)]`. Do not reorder or change these fields
/// without re-checking against the pinned `reth_tasks` version.
#[derive(Debug)]
pub struct LocalGracefulShutdown {
    // Never read directly (hence the `_` prefix); these fields only give this struct the
    // shape of the reth type it is transmuted into.
    _shutdown: Shutdown,
    _guard: Option<LocalGracefulShutdownGuard>,
}
impl LocalGracefulShutdown {
    /// Bundles a shutdown future with a guard whose creation has already been counted.
    pub fn new(shutdown: Shutdown, guard: LocalGracefulShutdownGuard) -> Self {
        Self { _shutdown: shutdown, _guard: Some(guard) }
    }
}
/// Registers one in-flight graceful task by incrementing the shared counter that
/// `BrontesTaskManager`'s graceful shutdown waits on.
///
/// NOTE(review): no matching decrement is visible in this file section. Presumably it
/// happens in the `Drop` impl of reth's guard once `LocalGracefulShutdown` has been
/// transmuted into `GracefulShutdown` — confirm. Otherwise graceful shutdown can only end
/// by timing out.
#[derive(Debug)]
#[must_use = "if unused the task will not be gracefully shutdown"]
#[allow(unused)]
pub struct LocalGracefulShutdownGuard(Arc<AtomicUsize>);
impl LocalGracefulShutdownGuard {
    /// Increments `counter` and keeps a handle to it.
    pub(crate) fn new(counter: Arc<AtomicUsize>) -> Self {
        counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        Self(counter)
    }
}