From 665273047779a5f4f7c861fbc9720b6457dd7dcc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Robert=20Sch=C3=BCtte?= Date: Fri, 12 Jun 2026 14:18:28 +0200 Subject: [PATCH 1/2] ChannelTask can be terminated ChannelTask can be terminated and the whole worker is closed and removed from the pool. The pool will open a new worker afterwards. One ChannelTask is only allowed per WebWorker, so termination has no effect on other tasks. --- README.md | 10 +- src/channel_task.rs | 78 +++++++++++-- src/error.rs | 12 ++ src/lib.rs | 1 + src/pool/mod.rs | 249 ++++++++++++++++++++++++++++++---------- src/webworker/worker.rs | 18 ++- test/src/channel.rs | 139 +++++++++++++++++++++- test/src/lib.rs | 4 + 8 files changed, 434 insertions(+), 77 deletions(-) diff --git a/README.md b/README.md index cd7a301..b7c3b9b 100644 --- a/README.md +++ b/README.md @@ -201,9 +201,14 @@ let progress: Progress = task.recv().await.unwrap(); task.send(&Continue { should_continue: true }); // Wait for task completion -let result = task.result().await; +let result = task.result().await?; ``` +Pool channel tasks exclusively lease one worker until `result()` completes. Calling +`terminate()`, or dropping an unfinished task, terminates that worker and replaces it +in the same pool slot. The slot is unavailable to the scheduler until its replacement +has initialized. + ### Bundler support (Vite) The recommended approach for Vite is to place the wasm-pack output in Vite's `publicDir`. This keeps the glue code and WASM binary as static assets, which is required because each @@ -271,6 +276,9 @@ options.precompile_wasm = Some(true); init_worker_pool(options).await.unwrap(); ``` +The pool retains this compiled module and also uses it when replacing terminated +workers, so replacement does not fetch the WASM binary again. + ### Idle timeout Workers can be automatically terminated after a period of inactivity and transparently recreated when new tasks arrive. This is useful for freeing resources in applications where worker usage is intermittent: diff --git a/src/channel_task.rs b/src/channel_task.rs index 23f26f2..168afc3 100644 --- a/src/channel_task.rs +++ b/src/channel_task.rs @@ -3,7 +3,9 @@ use std::marker::PhantomData; use serde::{de::DeserializeOwned, Serialize}; use tokio::sync::oneshot; -use crate::{channel::Channel, convert::from_bytes}; +use crate::{channel::Channel, convert::from_bytes, error::TaskError}; + +type LifecycleCallback = Box; /// A handle to a running channel task on a WebWorker. /// @@ -24,11 +26,13 @@ use crate::{channel::Channel, convert::from_bytes}; /// let progress: Progress = task.recv().await.expect("progress"); /// task.send(&Continue { should_continue: true }); /// -/// let result: ProcessResult = task.result().await; +/// let result: ProcessResult = task.result().await.expect("worker terminated"); /// ``` pub struct ChannelTask { channel: Channel, - result_rx: oneshot::Receiver>, + result_rx: Option>>, + on_complete: Option, + on_terminate: Option, _phantom: PhantomData, } @@ -36,13 +40,35 @@ impl ChannelTask { /// Create a new `ChannelTask` from a channel and a result receiver. #[doc(hidden)] pub fn new(channel: Channel, result_rx: oneshot::Receiver>) -> Self { + Self::with_lifecycle(channel, result_rx, None, None) + } + + #[doc(hidden)] + pub(crate) fn with_lifecycle( + channel: Channel, + result_rx: oneshot::Receiver>, + on_complete: Option, + on_terminate: Option, + ) -> Self { Self { channel, - result_rx, + result_rx: Some(result_rx), + on_complete, + on_terminate, _phantom: PhantomData, } } + pub(crate) fn with_callbacks( + mut self, + on_complete: LifecycleCallback, + on_terminate: LifecycleCallback, + ) -> Self { + self.on_complete = Some(on_complete); + self.on_terminate = Some(on_terminate); + self + } + /// Receive the next deserialized message from the worker. /// /// Returns `None` if the channel's sender side has been dropped @@ -69,11 +95,45 @@ impl ChannelTask { } /// Await the task's final result, consuming the `ChannelTask`. - pub async fn result(self) -> R { - let bytes = self + pub async fn result(mut self) -> Result { + let result_rx = self .result_rx - .await - .expect("WebWorker result sender dropped"); - from_bytes(&bytes) + .take() + .ok_or(TaskError::ResultAlreadyConsumed)?; + let result = result_rx.await.map_err(|_| TaskError::WorkerTerminated); + + match result { + Ok(bytes) => { + self.on_terminate.take(); + if let Some(on_complete) = self.on_complete.take() { + on_complete(); + } + Ok(from_bytes(&bytes)) + } + Err(error) => { + if let Some(on_terminate) = self.on_terminate.take() { + on_terminate(); + } + Err(error) + } + } + } + + /// Terminate the worker running this task. + /// + /// Pool tasks exclusively lease their worker. The pool replaces the terminated + /// worker in the same slot before making that slot schedulable again. + pub fn terminate(mut self) { + if let Some(on_terminate) = self.on_terminate.take() { + on_terminate(); + } + } +} + +impl Drop for ChannelTask { + fn drop(&mut self) { + if let Some(on_terminate) = self.on_terminate.take() { + on_terminate(); + } } } diff --git a/src/error.rs b/src/error.rs index 1219e9e..50fd5a8 100644 --- a/src/error.rs +++ b/src/error.rs @@ -8,6 +8,18 @@ use thiserror::Error; #[error("WebWorker capacity reached")] pub struct Full; +/// This error is returned when a channel task cannot produce a result. +#[derive(Debug, Error, PartialEq, Eq)] +#[non_exhaustive] +pub enum TaskError { + /// The worker running the task was terminated before returning a result. + #[error("WebWorker was terminated")] + WorkerTerminated, + /// The channel task result was already consumed. + #[error("ChannelTask result already consumed")] + ResultAlreadyConsumed, +} + /// This error is returned during the creation of a new web worker. /// It covers generic errors in the actual creation and import errors /// during the initialization. diff --git a/src/lib.rs b/src/lib.rs index c18254e..3441568 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -97,6 +97,7 @@ #![allow(clippy::borrowed_box)] pub use channel::Channel; pub use channel_task::ChannelTask; +pub use error::TaskError; pub use global::{ has_worker_pool, init_optimized_worker_pool, init_worker_pool, worker_pool, AlreadyInitialized, }; diff --git a/src/pool/mod.rs b/src/pool/mod.rs index d2675a7..541e0e0 100644 --- a/src/pool/mod.rs +++ b/src/pool/mod.rs @@ -118,11 +118,19 @@ impl WorkerPoolOptions { /// The state of a single worker slot in the pool. enum WorkerSlot { /// Worker is active and can accept tasks. - Active(WebWorker), + Active { + worker: Rc, + generation: u64, + }, + /// Worker is exclusively leased to a channel task. + Leased { + worker: Rc, + generation: u64, + }, /// Worker is being created (prevents duplicate creation during async init). - Creating, + Creating { generation: u64 }, /// Worker was terminated by idle timeout and can be recreated. - Empty, + Empty { generation: u64 }, } pub struct WebWorkerPool { @@ -143,7 +151,7 @@ pub struct WebWorkerPool { /// Idle checker interval ID (for clearInterval on Drop). _idle_checker_id: Option, /// Notify waiting tasks when a worker becomes available after creation. - worker_ready: tokio::sync::Notify, + worker_ready: Rc, } impl Drop for WebWorkerPool { @@ -225,7 +233,12 @@ impl WebWorkerPool { let slots: Rc>> = Rc::new( workers .into_iter() - .map(|w| RefCell::new(WorkerSlot::Active(w))) + .map(|worker| { + RefCell::new(WorkerSlot::Active { + worker: Rc::new(worker), + generation: 0, + }) + }) .collect(), ); @@ -237,11 +250,16 @@ impl WebWorkerPool { for i in 0..slots_clone.len() { let should_terminate = { let s = slots_clone[i].borrow(); - matches!(&*s, WorkerSlot::Active(ref w) - if w.current_load() == 0 && (now - w.last_active()) >= timeout as f64) + matches!(&*s, WorkerSlot::Active { worker, .. } + if worker.current_load() == 0 + && (now - worker.last_active()) >= timeout as f64) }; if should_terminate { - *slots_clone[i].borrow_mut() = WorkerSlot::Empty; + let generation = match &*slots_clone[i].borrow() { + WorkerSlot::Active { generation, .. } => generation + 1, + _ => continue, + }; + *slots_clone[i].borrow_mut() = WorkerSlot::Empty { generation }; } } }); @@ -266,7 +284,7 @@ impl WebWorkerPool { pool_path_bg: options.path_bg.clone(), _idle_checker_cb: idle_checker_cb, _idle_checker_id: idle_checker_id, - worker_ready: tokio::sync::Notify::new(), + worker_ready: Rc::new(tokio::sync::Notify::new()), }) } @@ -306,7 +324,7 @@ impl WebWorkerPool { /// /// let progress: Progress = task.recv().await.expect("progress"); /// task.send(&Continue { should_continue: true }); - /// let result: ProcessResult = task.result().await; + /// let result: ProcessResult = task.result().await.expect("worker terminated"); /// ``` pub async fn run_channel(&self, func: WebWorkerChannelFn, arg: &T) -> ChannelTask where @@ -338,45 +356,58 @@ impl WebWorkerPool { /// Acquires an active worker slot, recreating a terminated worker if needed. async fn acquire_worker(&self) -> usize { loop { + let notified = self.worker_ready.notified(); let loads = self.compute_loads(); if let Some(id) = self.scheduler.schedule(&loads) { return id; } - // No active workers. Find first Empty slot and recreate. - let empty_slot = self - .slots - .iter() - .position(|slot| matches!(&*slot.borrow(), WorkerSlot::Empty)); - if let Some(i) = empty_slot { - *self.slots[i].borrow_mut() = WorkerSlot::Creating; + if self.recreate_empty_worker().await { + continue; } - if let Some(slot_id) = empty_slot { - let worker_result = WebWorker::with_path_and_module( - self.pool_path.as_deref(), - self.pool_path_bg.as_deref(), - None, - self.wasm_module.clone(), - ) - .await; - match worker_result { - Ok(worker) => { - *self.slots[slot_id].borrow_mut() = WorkerSlot::Active(worker); - self.worker_ready.notify_waiters(); - return slot_id; - } - Err(_) => { - *self.slots[slot_id].borrow_mut() = WorkerSlot::Empty; - self.worker_ready.notify_waiters(); - panic!("Couldn't recreate worker"); - } - } - } + // All slots are leased, busy, or being created. + notified.await; + } + } - // All slots are Creating — wait for one to finish. - self.worker_ready.notified().await; + /// Recreate one empty worker slot. Returns whether an empty slot was found. + async fn recreate_empty_worker(&self) -> bool { + let empty_slot = + self.slots + .iter() + .enumerate() + .find_map(|(i, slot)| match &*slot.borrow() { + WorkerSlot::Empty { generation } => Some((i, *generation)), + _ => None, + }); + let Some((slot_id, generation)) = empty_slot else { + return false; + }; + + *self.slots[slot_id].borrow_mut() = WorkerSlot::Creating { generation }; + let worker_result = WebWorker::with_path_and_module( + self.pool_path.as_deref(), + self.pool_path_bg.as_deref(), + None, + self.wasm_module.clone(), + ) + .await; + match worker_result { + Ok(worker) => { + *self.slots[slot_id].borrow_mut() = WorkerSlot::Active { + worker: Rc::new(worker), + generation, + }; + self.worker_ready.notify_waiters(); + } + Err(_) => { + *self.slots[slot_id].borrow_mut() = WorkerSlot::Empty { generation }; + self.worker_ready.notify_waiters(); + panic!("Couldn't recreate worker"); + } } + true } /// Compute per-slot loads for the scheduler. @@ -384,17 +415,48 @@ impl WebWorkerPool { self.slots .iter() .map(|slot| match &*slot.borrow() { - WorkerSlot::Active(w) => Some(w.current_load()), + WorkerSlot::Active { worker, .. } => Some(worker.current_load()), _ => None, }) .collect() } + /// Acquires an idle worker and exclusively leases its slot to a channel task. + async fn acquire_channel_worker(&self) -> (usize, Rc, u64) { + loop { + let notified = self.worker_ready.notified(); + let loads = self + .slots + .iter() + .map(|slot| match &*slot.borrow() { + WorkerSlot::Active { worker, .. } if worker.current_load() == 0 => Some(0), + _ => None, + }) + .collect::>(); + + if let Some(slot_id) = self.scheduler.schedule(&loads) { + let mut slot = self.slots[slot_id].borrow_mut(); + if let WorkerSlot::Active { worker, generation } = &*slot { + let worker = Rc::clone(worker); + let generation = *generation; + *slot = WorkerSlot::Leased { + worker: Rc::clone(&worker), + generation, + }; + return (slot_id, worker, generation); + } + } + + if self.recreate_empty_worker().await { + continue; + } + + notified.await; + } + } + /// Determines the worker to run a simple task on using the scheduler /// and runs the task. - // Per-slot RefCell: holding a borrow across await is safe because - // the idle checker only terminates slots with zero load (i.e., not borrowed). - #[allow(clippy::await_holding_refcell_ref)] pub(crate) async fn run_internal(&self, func: WebWorkerFn, arg: A) -> R where A: Borrow, @@ -402,18 +464,17 @@ impl WebWorkerPool { R: Serialize + for<'de> Deserialize<'de>, { let worker_id = self.acquire_worker().await; - let slot = self.slots[worker_id].borrow(); - match &*slot { - WorkerSlot::Active(worker) => worker.run_internal(func, arg.borrow()).await, + let worker = match &*self.slots[worker_id].borrow() { + WorkerSlot::Active { worker, .. } => Rc::clone(worker), _ => unreachable!("acquire_worker guarantees Active slot"), - } + }; + let result = worker.run_internal(func, arg.borrow()).await; + self.worker_ready.notify_waiters(); + result } /// Determines the worker to run a channel task on using the scheduler /// and runs the task. - // Per-slot RefCell: holding a borrow across await is safe because - // the idle checker only terminates slots with zero load (i.e., not borrowed). - #[allow(clippy::await_holding_refcell_ref)] pub(crate) async fn run_channel_internal( &self, func: WebWorkerChannelFn, @@ -423,12 +484,73 @@ impl WebWorkerPool { T: Serialize + for<'de> Deserialize<'de>, R: Serialize + for<'de> Deserialize<'de>, { - let worker_id = self.acquire_worker().await; - let slot = self.slots[worker_id].borrow(); - match &*slot { - WorkerSlot::Active(worker) => worker.run_channel_internal(func, arg).await, - _ => unreachable!("acquire_worker guarantees Active slot"), - } + let (worker_id, worker, generation) = self.acquire_channel_worker().await; + let task = worker.run_channel_internal(func, arg).await; + + let release_slots = Rc::clone(&self.slots); + let release_ready = Rc::clone(&self.worker_ready); + let release_worker = Rc::clone(&worker); + let on_complete = Box::new(move || { + let mut slot = release_slots[worker_id].borrow_mut(); + if matches!(&*slot, WorkerSlot::Leased { generation: current, .. } if *current == generation) + { + *slot = WorkerSlot::Active { + worker: release_worker, + generation, + }; + release_ready.notify_waiters(); + } + }); + + let terminate_slots = Rc::clone(&self.slots); + let terminate_ready = Rc::clone(&self.worker_ready); + let terminate_path = self.pool_path.clone(); + let terminate_path_bg = self.pool_path_bg.clone(); + let terminate_module = self.wasm_module.clone(); + let on_terminate = Box::new(move || { + let replacement_generation = generation + 1; + { + let mut slot = terminate_slots[worker_id].borrow_mut(); + if !matches!(&*slot, WorkerSlot::Leased { generation: current, .. } if *current == generation) + { + return; + } + *slot = WorkerSlot::Creating { + generation: replacement_generation, + }; + } + + worker.terminate(); + let slots = Rc::clone(&terminate_slots); + let ready = Rc::clone(&terminate_ready); + wasm_bindgen_futures::spawn_local(async move { + let replacement = WebWorker::with_path_and_module( + terminate_path.as_deref(), + terminate_path_bg.as_deref(), + None, + terminate_module, + ) + .await; + + let mut slot = slots[worker_id].borrow_mut(); + if !matches!(&*slot, WorkerSlot::Creating { generation } if *generation == replacement_generation) + { + return; + } + *slot = match replacement { + Ok(worker) => WorkerSlot::Active { + worker: Rc::new(worker), + generation: replacement_generation, + }, + Err(_) => WorkerSlot::Empty { + generation: replacement_generation, + }, + }; + ready.notify_waiters(); + }); + }); + + task.with_callbacks(on_complete, on_terminate) } /// Return the number of tasks currently queued to this worker pool. @@ -436,8 +558,10 @@ impl WebWorkerPool { self.slots .iter() .map(|slot| match &*slot.borrow() { - WorkerSlot::Active(w) => w.current_load(), - _ => 0, + WorkerSlot::Active { worker, .. } | WorkerSlot::Leased { worker, .. } => { + worker.current_load() + } + WorkerSlot::Creating { .. } | WorkerSlot::Empty { .. } => 0, }) .sum() } @@ -451,7 +575,12 @@ impl WebWorkerPool { pub fn num_active_workers(&self) -> usize { self.slots .iter() - .filter(|s| matches!(&*RefCell::borrow(s), WorkerSlot::Active(_))) + .filter(|s| { + matches!( + &*RefCell::borrow(s), + WorkerSlot::Active { .. } | WorkerSlot::Leased { .. } + ) + }) .count() } diff --git a/src/webworker/worker.rs b/src/webworker/worker.rs index 0258010..572b182 100644 --- a/src/webworker/worker.rs +++ b/src/webworker/worker.rs @@ -263,7 +263,7 @@ impl WebWorker { /// /// let progress: Progress = task.recv().await.expect("progress"); /// task.send(&Continue { should_continue: true }); - /// let result: ProcessResult = task.result().await; + /// let result: ProcessResult = task.result().await.expect("worker terminated"); /// ``` pub async fn run_channel(&self, func: WebWorkerChannelFn, arg: &T) -> ChannelTask where @@ -406,7 +406,14 @@ impl WebWorker { // Send the request and get a receiver for the result bytes. let result_rx = self.send_channel_request(func.name, arg, worker_port); - ChannelTask::new(channel, result_rx) + let worker = self.worker.clone(); + let open_tasks = Rc::clone(&self.open_tasks); + let on_terminate = Box::new(move || { + worker.terminate(); + open_tasks.borrow_mut().clear(); + }); + + ChannelTask::with_lifecycle(channel, result_rx, None, Some(on_terminate)) } /// This function handles the communication with the worker @@ -530,10 +537,17 @@ impl WebWorker { pub fn last_active(&self) -> f64 { self.last_active.get() } + + /// Terminate this worker and fail all tasks currently assigned to it. + pub(crate) fn terminate(&self) { + self.worker.terminate(); + self.open_tasks.borrow_mut().clear(); + } } impl Drop for WebWorker { fn drop(&mut self) { self.worker.terminate(); + self.open_tasks.borrow_mut().clear(); } } diff --git a/test/src/channel.rs b/test/src/channel.rs index 8b42883..b1b4236 100644 --- a/test/src/channel.rs +++ b/test/src/channel.rs @@ -1,8 +1,23 @@ use serde::{Deserialize, Serialize}; +use wasm_bindgen::UnwrapThrowExt; +use wasm_bindgen_futures::JsFuture; use wasmworker::webworker_channel_fn; -use wasmworker::{webworker_channel, worker_pool, Channel, WebWorker}; +use wasmworker::{ + webworker, webworker_channel, worker_pool, Channel, TaskError, WebWorker, WebWorkerPool, + WorkerPoolOptions, +}; -use crate::js_assert_eq; +use crate::{js_assert_eq, raw::sort}; + +async fn sleep_ms(ms: u32) { + let promise = js_sys::Promise::new(&mut |resolve, _| { + web_sys::window() + .unwrap_throw() + .set_timeout_with_callback_and_timeout_and_arguments_0(&resolve, ms as i32) + .unwrap_throw(); + }); + JsFuture::from(promise).await.unwrap_throw(); +} /// Progress message sent from worker to main thread. #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] @@ -82,7 +97,7 @@ pub(crate) async fn can_use_channel_with_worker() { js_assert_eq!(final_progress.percent, 100, "Should be at 100%"); // Now wait for the task result - let result = task.result().await; + let result = task.result().await.expect("Channel task should succeed"); js_assert_eq!(result.items_processed, 10, "Should process all items"); js_assert_eq!(result.was_cancelled, false, "Should not be cancelled"); } @@ -107,11 +122,29 @@ pub(crate) async fn can_cancel_channel_task() { }); // Wait for result (no 100% progress expected since we cancelled) - let result = task.result().await; + let result = task.result().await.expect("Channel task should succeed"); js_assert_eq!(result.items_processed, 5, "Should process half the items"); js_assert_eq!(result.was_cancelled, true, "Should be cancelled"); } +/// Test that worker termination is reported through the channel task result. +pub(crate) async fn channel_task_reports_worker_termination() { + let worker = WebWorker::new(None).await.expect("Couldn't create worker"); + let data = vec![1, 2, 3, 4]; + let task = worker + .run_channel(webworker_channel!(process_with_progress), &data) + .await; + let _: Progress = task.recv().await.expect("Should receive progress"); + + drop(worker); + let was_terminated = matches!(task.result().await, Err(TaskError::WorkerTerminated)); + js_assert_eq!( + was_terminated, + true, + "Terminated worker should fail its channel task" + ); +} + /// Test that channel functions work with the worker pool. pub(crate) async fn can_use_channel_with_pool() { let pool = worker_pool().await; @@ -136,7 +169,103 @@ pub(crate) async fn can_use_channel_with_pool() { js_assert_eq!(final_progress.percent, 100, "Should be at 100%"); // Wait for completion - let result = task.result().await; + let result = task.result().await.expect("Channel task should succeed"); js_assert_eq!(result.items_processed, 4, "Should process all items"); js_assert_eq!(result.was_cancelled, false, "Should not be cancelled"); } + +/// Test that a pool channel task exclusively leases its worker. +pub(crate) async fn channel_task_exclusively_leases_worker() { + let pool = std::rc::Rc::new( + WebWorkerPool::with_num_workers(1) + .await + .expect("Couldn't create worker pool"), + ); + let data = vec![1, 2, 3, 4]; + let task = pool + .run_channel(webworker_channel!(process_with_progress), &data) + .await; + let _: Progress = task.recv().await.expect("Should receive progress"); + + let queued_pool = std::rc::Rc::clone(&pool); + let queued_result = std::rc::Rc::new(std::cell::RefCell::new(None)); + let task_result = std::rc::Rc::clone(&queued_result); + wasm_bindgen_futures::spawn_local(async move { + let input: Box<[u8]> = vec![3, 1, 2].into(); + let result = queued_pool.run_bytes(webworker!(sort), &input).await; + *task_result.borrow_mut() = Some(result); + }); + + sleep_ms(50).await; + js_assert_eq!( + queued_result.borrow().is_none(), + true, + "Ordinary task should wait for channel lease" + ); + + task.send(&Continue { + should_continue: true, + }); + let _ = task + .result() + .await + .expect("Channel task should release its worker"); + while queued_result.borrow().is_none() { + sleep_ms(10).await; + } + let sorted = queued_result + .borrow_mut() + .take() + .expect("Queued task should complete"); + js_assert_eq!(sorted, Box::<[u8]>::from([1, 2, 3])); +} + +/// Test that explicit termination replaces the leased worker before reuse. +pub(crate) async fn terminating_channel_task_replaces_worker() { + let mut options = WorkerPoolOptions::new(); + options.num_workers = Some(1); + options.precompile_wasm = Some(true); + let pool = WebWorkerPool::with_options(options) + .await + .expect("Couldn't create worker pool"); + let data = vec![1, 2, 3, 4]; + let task = pool + .run_channel(webworker_channel!(process_with_progress), &data) + .await; + let _: Progress = task.recv().await.expect("Should receive progress"); + + task.terminate(); + js_assert_eq!( + pool.num_active_workers(), + 0, + "Terminated slot should not be schedulable during replacement" + ); + + let input: Box<[u8]> = vec![3, 1, 2].into(); + let sorted = pool.run_bytes(webworker!(sort), &input).await; + js_assert_eq!(sorted, Box::<[u8]>::from([1, 2, 3])); + js_assert_eq!(pool.num_active_workers(), 1); +} + +/// Test that dropping an unfinished channel task also replaces its worker. +pub(crate) async fn dropping_channel_task_replaces_worker() { + let pool = WebWorkerPool::with_num_workers(1) + .await + .expect("Couldn't create worker pool"); + let data = vec![1, 2, 3, 4]; + let task = pool + .run_channel(webworker_channel!(process_with_progress), &data) + .await; + let _: Progress = task.recv().await.expect("Should receive progress"); + + drop(task); + js_assert_eq!( + pool.num_active_workers(), + 0, + "Dropped task should make its slot unavailable during replacement" + ); + + let input: Box<[u8]> = vec![3, 1, 2].into(); + let sorted = pool.run_bytes(webworker!(sort), &input).await; + js_assert_eq!(sorted, Box::<[u8]>::from([1, 2, 3])); +} diff --git a/test/src/lib.rs b/test/src/lib.rs index 554aa65..96e0252 100644 --- a/test/src/lib.rs +++ b/test/src/lib.rs @@ -38,7 +38,11 @@ pub async fn run_tests() { // Channel tests can_use_channel_with_worker().await; can_cancel_channel_task().await; + channel_task_reports_worker_termination().await; can_use_channel_with_pool().await; + channel_task_exclusively_leases_worker().await; + terminating_channel_task_replaces_worker().await; + dropping_channel_task_replaces_worker().await; // Pool configuration tests can_use_precompiled_wasm().await; From 36ef785748ce1afbafd1d75112c450237c557bb4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Robert=20Sch=C3=BCtte?= Date: Fri, 12 Jun 2026 14:43:46 +0200 Subject: [PATCH 2/2] Termination is always possible Termination only needs read var access, because the controlling moved to a cloneable struct. This allows to terminate the whole webworker from any place at any time. --- README.md | 17 ++++++ src/channel_task.rs | 123 +++++++++++++++++++++++++++++++++++--------- src/lib.rs | 2 +- test/src/channel.rs | 91 +++++++++++++++++++++++++++++++- test/src/lib.rs | 3 ++ 5 files changed, 211 insertions(+), 25 deletions(-) diff --git a/README.md b/README.md index b7c3b9b..1eda1aa 100644 --- a/README.md +++ b/README.md @@ -204,6 +204,23 @@ task.send(&Continue { should_continue: true }); let result = task.result().await?; ``` +To cancel from another future, clone the task controller before moving the task: + +```rust,ignore +let control = task.control(); +wasm_bindgen_futures::spawn_local(async move { + while let Some(progress) = task.recv::().await { + // Handle progress. + } +}); + +control.terminate(); +``` + +Termination wakes blocked `recv()` and `recv_bytes()` calls, which return `None`. +Blocked `result()` calls return `TaskError::WorkerTerminated`. Repeated calls to +`terminate()` are harmless. + Pool channel tasks exclusively lease one worker until `result()` completes. Calling `terminate()`, or dropping an unfinished task, terminates that worker and replaces it in the same pool slot. The slot is unavailable to the scheduler until its replacement diff --git a/src/channel_task.rs b/src/channel_task.rs index 168afc3..319441f 100644 --- a/src/channel_task.rs +++ b/src/channel_task.rs @@ -1,7 +1,12 @@ -use std::marker::PhantomData; +use std::{ + cell::{Cell, RefCell}, + marker::PhantomData, + rc::Rc, +}; +use futures::{future::select, pin_mut}; use serde::{de::DeserializeOwned, Serialize}; -use tokio::sync::oneshot; +use tokio::sync::{oneshot, watch}; use crate::{channel::Channel, convert::from_bytes, error::TaskError}; @@ -31,11 +36,70 @@ type LifecycleCallback = Box; pub struct ChannelTask { channel: Channel, result_rx: Option>>, + control: ChannelTaskControl, on_complete: Option, - on_terminate: Option, _phantom: PhantomData, } +/// A cloneable handle for terminating a running [`ChannelTask`]. +#[derive(Clone)] +pub struct ChannelTaskControl { + inner: Rc, +} + +struct ChannelTaskControlInner { + terminated: Cell, + on_terminate: RefCell>, + close_tx: watch::Sender, +} + +impl ChannelTaskControl { + fn new(on_terminate: Option) -> Self { + let (close_tx, _) = watch::channel(false); + Self { + inner: Rc::new(ChannelTaskControlInner { + terminated: Cell::new(false), + on_terminate: RefCell::new(on_terminate), + close_tx, + }), + } + } + + /// Terminate the worker running the associated channel task. + /// + /// Repeated calls are harmless. + pub fn terminate(&self) { + if self.inner.terminated.replace(true) { + return; + } + + let _ = self.inner.close_tx.send(true); + if let Some(callback) = self.inner.on_terminate.borrow_mut().take() { + callback(); + } + } + + fn subscribe(&self) -> watch::Receiver { + self.inner.close_tx.subscribe() + } + + fn is_terminated(&self) -> bool { + self.inner.terminated.get() + } + + fn is_armed(&self) -> bool { + self.inner.on_terminate.borrow().is_some() + } + + fn disarm(&self) { + self.inner.on_terminate.borrow_mut().take(); + } + + fn set_on_terminate(&self, callback: LifecycleCallback) { + *self.inner.on_terminate.borrow_mut() = Some(callback); + } +} + impl ChannelTask { /// Create a new `ChannelTask` from a channel and a result receiver. #[doc(hidden)] @@ -53,8 +117,8 @@ impl ChannelTask { Self { channel, result_rx: Some(result_rx), + control: ChannelTaskControl::new(on_terminate), on_complete, - on_terminate, _phantom: PhantomData, } } @@ -65,23 +129,40 @@ impl ChannelTask { on_terminate: LifecycleCallback, ) -> Self { self.on_complete = Some(on_complete); - self.on_terminate = Some(on_terminate); + self.control.set_on_terminate(on_terminate); self } + /// Return a cloneable controller that can terminate this task externally. + pub fn control(&self) -> ChannelTaskControl { + self.control.clone() + } + /// Receive the next deserialized message from the worker. /// - /// Returns `None` if the channel's sender side has been dropped - /// (i.e., the worker has finished and closed the channel). + /// Returns `None` if the channel closes or the task is terminated. pub async fn recv(&self) -> Option { - self.channel.recv().await + let bytes = self.recv_bytes().await?; + Some(from_bytes(&bytes)) } /// Receive raw bytes from the worker. /// - /// Returns `None` if the channel's sender side has been dropped. + /// Returns `None` if the channel closes or the task is terminated. pub async fn recv_bytes(&self) -> Option> { - self.channel.recv_bytes().await + if self.control.is_terminated() { + return None; + } + + let mut close_rx = self.control.subscribe(); + let message = self.channel.recv_bytes(); + let closed = close_rx.changed(); + pin_mut!(message, closed); + + match select(message, closed).await { + futures::future::Either::Left((message, _)) if !self.control.is_terminated() => message, + _ => None, + } } /// Send a serialized message to the worker. @@ -103,18 +184,16 @@ impl ChannelTask { let result = result_rx.await.map_err(|_| TaskError::WorkerTerminated); match result { - Ok(bytes) => { - self.on_terminate.take(); + Ok(bytes) if !self.control.is_terminated() => { + self.control.disarm(); if let Some(on_complete) = self.on_complete.take() { on_complete(); } Ok(from_bytes(&bytes)) } - Err(error) => { - if let Some(on_terminate) = self.on_terminate.take() { - on_terminate(); - } - Err(error) + Ok(_) | Err(_) => { + self.control.terminate(); + Err(TaskError::WorkerTerminated) } } } @@ -123,17 +202,15 @@ impl ChannelTask { /// /// Pool tasks exclusively lease their worker. The pool replaces the terminated /// worker in the same slot before making that slot schedulable again. - pub fn terminate(mut self) { - if let Some(on_terminate) = self.on_terminate.take() { - on_terminate(); - } + pub fn terminate(&self) { + self.control.terminate(); } } impl Drop for ChannelTask { fn drop(&mut self) { - if let Some(on_terminate) = self.on_terminate.take() { - on_terminate(); + if self.control.is_armed() { + self.control.terminate(); } } } diff --git a/src/lib.rs b/src/lib.rs index 3441568..d28a2f9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -96,7 +96,7 @@ #![allow(clippy::borrowed_box)] pub use channel::Channel; -pub use channel_task::ChannelTask; +pub use channel_task::{ChannelTask, ChannelTaskControl}; pub use error::TaskError; pub use global::{ has_worker_pool, init_optimized_worker_pool, init_worker_pool, worker_pool, AlreadyInitialized, diff --git a/test/src/channel.rs b/test/src/channel.rs index b1b4236..94d334d 100644 --- a/test/src/channel.rs +++ b/test/src/channel.rs @@ -145,6 +145,68 @@ pub(crate) async fn channel_task_reports_worker_termination() { ); } +/// Test that external termination wakes a task blocked in recv(). +pub(crate) async fn external_termination_wakes_channel_recv() { + let pool = WebWorkerPool::with_num_workers(1) + .await + .expect("Couldn't create worker pool"); + let data = vec![1, 2, 3, 4]; + let task = pool + .run_channel(webworker_channel!(process_with_progress), &data) + .await; + let _: Progress = task.recv().await.expect("Should receive progress"); + let control = task.control(); + + let recv_woke = std::rc::Rc::new(std::cell::Cell::new(false)); + let recv_woke_task = std::rc::Rc::clone(&recv_woke); + wasm_bindgen_futures::spawn_local(async move { + recv_woke_task.set(task.recv::().await.is_none()); + }); + + sleep_ms(10).await; + control.terminate(); + while !recv_woke.get() { + sleep_ms(10).await; + } + + let input: Box<[u8]> = vec![3, 1, 2].into(); + let sorted = pool.run_bytes(webworker!(sort), &input).await; + js_assert_eq!(sorted, Box::<[u8]>::from([1, 2, 3])); +} + +/// Test that external termination wakes a task blocked in result(). +pub(crate) async fn external_termination_wakes_channel_result() { + let pool = WebWorkerPool::with_num_workers(1) + .await + .expect("Couldn't create worker pool"); + let data = vec![1, 2, 3, 4]; + let task = pool + .run_channel(webworker_channel!(process_with_progress), &data) + .await; + let _: Progress = task.recv().await.expect("Should receive progress"); + let control = task.control(); + + let result_woke = std::rc::Rc::new(std::cell::Cell::new(false)); + let result_woke_task = std::rc::Rc::clone(&result_woke); + wasm_bindgen_futures::spawn_local(async move { + result_woke_task.set(matches!( + task.result().await, + Err(TaskError::WorkerTerminated) + )); + }); + + sleep_ms(10).await; + control.terminate(); + control.terminate(); + while !result_woke.get() { + sleep_ms(10).await; + } + + let input: Box<[u8]> = vec![3, 1, 2].into(); + let sorted = pool.run_bytes(webworker!(sort), &input).await; + js_assert_eq!(sorted, Box::<[u8]>::from([1, 2, 3])); +} + /// Test that channel functions work with the worker pool. pub(crate) async fn can_use_channel_with_pool() { let pool = worker_pool().await; @@ -174,6 +236,29 @@ pub(crate) async fn can_use_channel_with_pool() { js_assert_eq!(result.was_cancelled, false, "Should not be cancelled"); } +/// Test that successful completion disarms termination before dropping the task. +pub(crate) async fn successful_channel_result_keeps_worker_active() { + let pool = WebWorkerPool::with_num_workers(1) + .await + .expect("Couldn't create worker pool"); + let data = vec![1, 2, 3, 4]; + let task = pool + .run_channel(webworker_channel!(process_with_progress), &data) + .await; + let _: Progress = task.recv().await.expect("Should receive progress"); + task.send(&Continue { + should_continue: true, + }); + let _: Progress = task.recv().await.expect("Should receive final progress"); + let _ = task.result().await.expect("Channel task should succeed"); + + js_assert_eq!( + pool.num_active_workers(), + 1, + "Successful completion should not terminate its worker" + ); +} + /// Test that a pool channel task exclusively leases its worker. pub(crate) async fn channel_task_exclusively_leases_worker() { let pool = std::rc::Rc::new( @@ -234,12 +319,16 @@ pub(crate) async fn terminating_channel_task_replaces_worker() { .await; let _: Progress = task.recv().await.expect("Should receive progress"); - task.terminate(); + let control = task.control(); + control.terminate(); + control.terminate(); js_assert_eq!( pool.num_active_workers(), 0, "Terminated slot should not be schedulable during replacement" ); + let was_terminated = matches!(task.result().await, Err(TaskError::WorkerTerminated)); + js_assert_eq!(was_terminated, true); let input: Box<[u8]> = vec![3, 1, 2].into(); let sorted = pool.run_bytes(webworker!(sort), &input).await; diff --git a/test/src/lib.rs b/test/src/lib.rs index 96e0252..57ee186 100644 --- a/test/src/lib.rs +++ b/test/src/lib.rs @@ -39,7 +39,10 @@ pub async fn run_tests() { can_use_channel_with_worker().await; can_cancel_channel_task().await; channel_task_reports_worker_termination().await; + external_termination_wakes_channel_recv().await; + external_termination_wakes_channel_result().await; can_use_channel_with_pool().await; + successful_channel_result_keeps_worker_active().await; channel_task_exclusively_leases_worker().await; terminating_channel_task_replaces_worker().await; dropping_channel_task_replaces_worker().await;