From d6b7f51d767577ebaac468eb016a1ef7407610dc Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Tue, 12 May 2026 17:24:54 +0800 Subject: [PATCH] feat: start env in parallel Signed-off-by: Ruihang Xia --- sqlness/src/config.rs | 9 ++ sqlness/src/runner.rs | 197 +++++++++++++++++++++++++++++++++++++----- 2 files changed, 183 insertions(+), 23 deletions(-) diff --git a/sqlness/src/config.rs b/sqlness/src/config.rs index 4dec738..48b6c1e 100644 --- a/sqlness/src/config.rs +++ b/sqlness/src/config.rs @@ -41,6 +41,11 @@ pub struct Config { /// Default value: 1 #[builder(default = "Config::default_parallelism()")] pub parallelism: usize, + /// Number of environments to run in parallel. + /// + /// Default value: 1 + #[builder(default = "Config::default_env_parallelism()")] + pub env_parallelism: usize, } impl Config { @@ -83,6 +88,10 @@ impl Config { fn default_parallelism() -> usize { 1 } + + fn default_env_parallelism() -> usize { + 1 + } } /// Config for DatabaseBuilder diff --git a/sqlness/src/runner.rs b/sqlness/src/runner.rs index f9fc833..8a8936c 100644 --- a/sqlness/src/runner.rs +++ b/sqlness/src/runner.rs @@ -7,6 +7,8 @@ use std::str::FromStr; use std::sync::{Arc, Mutex}; use std::time::Instant; +use futures::stream::FuturesUnordered; +use futures::StreamExt; use prettydiff::basic::{DiffOp, SliceChangeset}; use prettydiff::diff_lines; use regex::Regex; @@ -48,39 +50,42 @@ impl Runner { let environments = self.collect_env()?; let mut errors = Vec::new(); let filter = Regex::new(&self.config.env_filter)?; + let mut runnable_environments = Vec::new(); + for env in environments { - if !filter.is_match(&env) { - println!("Environment({env}) is skipped!"); - continue; - } - let env_config = self.read_env_config(&env); - let config_path = env_config.as_path(); - let config_path = if config_path.exists() { - Some(config_path) + if filter.is_match(&env) { + runnable_environments.push(env); } else { - None - }; - let parallelism = self.config.parallelism.max(1); - let mut 
databases = Vec::with_capacity(parallelism); - println!("Creating enviroment with parallelism: {}", parallelism); - for id in 0..parallelism { - let db = self.env_controller.start(&env, id, config_path).await; - databases.push(db); - } - let run_result = self.run_env(&env, &databases).await; - for db in databases { - self.env_controller.stop(&env, db).await; + println!("Environment({env}) is skipped!"); } + } + let mut runnable_environments = runnable_environments.into_iter(); + let mut running = FuturesUnordered::new(); + for env in runnable_environments + .by_ref() + .take(self.config.env_parallelism.max(1)) + { + running.push(self.run_single_env(env)); + } + + let mut stop_scheduling = false; + while let Some((env, run_result)) = running.next().await { if let Err(e) = run_result { println!("Environment {env} run failed, error:{e:?}."); if self.config.fail_fast { - return Err(e); + stop_scheduling = true; } errors.push(e); } + + if !stop_scheduling { + if let Some(env) = runnable_environments.next() { + running.push(self.run_single_env(env)); + } + } } // only return first error @@ -91,6 +96,34 @@ impl Runner { Ok(()) } + async fn run_single_env(&self, env: String) -> (String, Result<()>) { + let env_config = self.read_env_config(&env); + let config_path = env_config.as_path(); + let config_path = if config_path.exists() { + Some(config_path) + } else { + None + }; + let parallelism = self.config.parallelism.max(1); + println!("Creating environment {env} with parallelism: {parallelism}"); + + let databases = futures::future::join_all( + (0..parallelism).map(|id| self.env_controller.start(&env, id, config_path)), + ) + .await; + + let run_result = self.run_env(&env, &databases).await; + + futures::future::join_all( + databases + .into_iter() + .map(|db| self.env_controller.stop(&env, db)), + ) + .await; + + (env, run_result) + } + fn read_env_config(&self, env: &str) -> PathBuf { let mut path_buf = std::path::PathBuf::new(); path_buf.push(&self.config.case_dir); 
@@ -157,10 +190,16 @@ impl Runner { db_idx, case_name, e ); if fail_fast { - errors.lock().expect("Failed to acquire lock on errors").push((case_name, e)); + errors + .lock() + .expect("Failed to acquire lock on errors") + .push((case_name, e)); return; } - errors.lock().expect("Failed to acquire lock on errors").push((case_name, e)); + errors + .lock() + .expect("Failed to acquire lock on errors") + .push((case_name, e)); } } } @@ -291,3 +330,115 @@ impl Runner { None } } + +#[cfg(test)] +mod tests { + use std::collections::HashSet; + use std::fmt::Display; + use std::fs; + use std::path::Path; + use std::sync::{Arc, Mutex}; + use std::time::{Duration, SystemTime, UNIX_EPOCH}; + + use async_trait::async_trait; + + use crate::case::QueryContext; + use crate::config::ConfigBuilder; + use crate::{Database, EnvController, Runner}; + + #[derive(Clone)] + struct RecordingEnvController { + state: Arc<Mutex<RecordingState>>, + } + + #[derive(Default)] + struct RecordingState { + active_starts: usize, + max_active_starts: usize, + active_envs: HashSet<String>, + max_active_envs: usize, + } + + struct RecordingDb { + env: String, + id: usize, + } + + #[async_trait] + impl EnvController for RecordingEnvController { + type DB = RecordingDb; + + async fn start(&self, env: &str, id: usize, _config: Option<&Path>) -> Self::DB { + { + let mut state = self.state.lock().unwrap(); + state.active_starts += 1; + state.max_active_starts = state.max_active_starts.max(state.active_starts); + + if id == 0 { + state.active_envs.insert(env.to_string()); + state.max_active_envs = state.max_active_envs.max(state.active_envs.len()); + } + } + + tokio::time::sleep(Duration::from_millis(50)).await; + + { + let mut state = self.state.lock().unwrap(); + state.active_starts -= 1; + } + + RecordingDb { + env: env.to_string(), + id, + } + } + + async fn stop(&self, _env: &str, database: Self::DB) { + if database.id == 0 { + self.state.lock().unwrap().active_envs.remove(&database.env); + } + } + } + + #[async_trait] + impl 
Database for RecordingDb { + async fn query(&self, _context: QueryContext, _query: String) -> Box<dyn Display> { + Box::new("") + } + } + + #[tokio::test] + async fn starts_databases_and_environments_in_parallel() { + let case_dir = new_case_dir(&["env_a", "env_b"]); + let state = Arc::new(Mutex::new(RecordingState::default())); + let controller = RecordingEnvController { + state: state.clone(), + }; + let config = ConfigBuilder::default() + .case_dir(case_dir.to_string_lossy().to_string()) + .parallelism(2) + .env_parallelism(2) + .build() + .unwrap(); + + Runner::new(config, controller).run().await.unwrap(); + + let state = state.lock().unwrap(); + assert_eq!(state.max_active_envs, 2); + assert_eq!(state.max_active_starts, 4); + + fs::remove_dir_all(case_dir).unwrap(); + } + + fn new_case_dir(environments: &[&str]) -> std::path::PathBuf { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_nanos(); + let dir = std::env::temp_dir().join(format!("sqlness-runner-test-{now}")); + for environment in environments { + fs::create_dir_all(dir.join(environment)).unwrap(); + } + dir + } +}