diff --git a/Cargo.lock b/Cargo.lock index 59988f4..c51c2e3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -358,6 +358,7 @@ dependencies = [ "bon", "byte-unit", "cgroups-rs", + "futures-lite", "rand 0.10.1", "rstest", "serde", diff --git a/Cargo.toml b/Cargo.toml index 889bec3..272ae11 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ A code runner library for online judge system bon = "3.8.1" byte-unit = "5.2.0" cgroups-rs = "0.5.0" +futures-lite = "2.6.1" serde = { version = "1.0.228", features = ["derive"], optional = true } state-shift = "2.1.1" tokio = { version = "1.48.0", features = ["full"] } diff --git a/src/judge.rs b/src/judge.rs index 904ce54..59e572a 100644 --- a/src/judge.rs +++ b/src/judge.rs @@ -1,6 +1,8 @@ use std::{env, io, marker::PhantomData, path::PathBuf, process::Stdio, time::Duration}; use bon::bon; +use byte_unit::Byte; +use futures_lite::{Stream, StreamExt}; use state_shift::{impl_state, type_state}; use tokio::{ fs, @@ -8,7 +10,7 @@ use tokio::{ }; use uuid::Uuid; -use crate::{Language, Metrics, Resource, Sandbox, Verdict}; +use crate::{AggregatedMetrics, Language, Metrics, Resource, Sandbox, Verdict}; const MAIN: &str = "main"; const CHECKER: &str = "checker"; @@ -203,4 +205,70 @@ impl Judge { memory_usage, }) } + + #[require(Compiled)] + pub async fn batch_run( + &self, + inputs: impl Iterator, + ) -> io::Result { + let mut verdict = Verdict::Accepted; + let mut total_run_time = Duration::ZERO; + let mut total_memory_usage = Byte::default(); + let mut count = 0; + + // running sequentially to enable early exit, saving resources + for input in inputs { + let metrics = self.run(input).await?; + total_run_time += metrics.run_time; + total_memory_usage = total_memory_usage + .add(metrics.memory_usage) + .expect("memory usage should not overflow u32"); + count += 1; + if metrics.verdict != Verdict::Accepted { + verdict = metrics.verdict; + break; + } + } + + Ok(AggregatedMetrics { + verdict, + average_run_time: total_run_time / count, + average_memory_usage: total_memory_usage + .divide(count as usize) + .expect("count must be greater than 0"), + }) + } + + #[require(Compiled)] + pub async fn streamed_batch_run( + &self, + mut inputs: impl Stream + std::marker::Unpin, + ) -> io::Result { + let mut verdict = Verdict::Accepted; + let mut total_run_time = Duration::ZERO; + let mut total_memory_usage = Byte::default(); + let mut count = 0; + + // running sequentially to enable early exit, saving resources + while let Some(input) = inputs.next().await { + let metrics = self.run(input).await?; + total_run_time += metrics.run_time; + total_memory_usage = total_memory_usage + .add(metrics.memory_usage) + .expect("memory usage should not overflow u32"); + count += 1; + if metrics.verdict != Verdict::Accepted { + verdict = metrics.verdict; + break; + } + } + + Ok(AggregatedMetrics { + verdict, + average_run_time: total_run_time / count, + average_memory_usage: total_memory_usage + .divide(count as usize) + .expect("count must be greater than 0"), + }) + } } diff --git a/src/metrics.rs b/src/metrics.rs index ce4b7a0..3689706 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -15,7 +15,6 @@ pub enum Verdict { } #[derive(Debug, Clone)] -#[cfg_attr(feature = "serde", derive(serde::Serialize))] pub struct Metrics { pub verdict: Verdict, pub run_time: Duration, @@ -23,3 +22,11 @@ pub struct Metrics { pub stdout: Vec, pub stderr: Vec, } + +#[derive(Debug, Clone)] +#[cfg_attr(feature = "serde", derive(serde::Serialize))] +pub struct AggregatedMetrics { + pub verdict: Verdict, + pub average_run_time: Duration, + pub average_memory_usage: Byte, +}