This repository contains the code for the S-Aware Gating (SAG) experiments on MNIST from the paper:
ASkDAgger: Active Skill-level Data Aggregation for Interactive Imitation Learning
Jelle Luijkx, Zlatan Ajanović, Laura Ferranti, Jens Kober — TMLR 2025
[arXiv] [OpenReview] [Project page]
Related repository from the same project: askdagger_cliport for the language-conditioned manipulation experiments.
S-Aware Gating (SAG) dynamically adjusts the gating threshold to operate in one of three modes:
- Sensitivity mode — guarantees a desired true-positive rate on queries. Use when system failures (false negatives) are costly.
- Specificity mode — guarantees a desired true-negative rate. Use when unnecessary teacher queries (false positives) are costly.
- Success mode — guarantees a minimum overall system success rate. SAG issues more queries when the novice underperforms and fewer once the target is comfortably met.
In this framing, queries are positives and autonomous actions are negatives: a false positive is an unnecessary query (lost autonomy), and a false negative is an uncaught invalid novice action (system failure).
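Under this framing, the gating rates follow directly from the query/validity outcomes. The sketch below is purely illustrative (it is not code from this repository, and the event encoding is an assumption made for this example):

```python
def gating_rates(events):
    """Compute sensitivity and specificity under the SAG framing.

    Each event is a pair (queried, novice_action_valid):
      queried & invalid     -> true positive  (necessary query)
      queried & valid       -> false positive (unnecessary query, lost autonomy)
      autonomous & valid    -> true negative  (justified autonomy)
      autonomous & invalid  -> false negative (uncaught failure)
    """
    tp = sum(1 for q, valid in events if q and not valid)
    fp = sum(1 for q, valid in events if q and valid)
    tn = sum(1 for q, valid in events if not q and valid)
    fn = sum(1 for q, valid in events if not q and not valid)
    sensitivity = tp / (tp + fn) if tp + fn else 1.0
    specificity = tn / (tn + fp) if tn + fp else 1.0
    return sensitivity, specificity
```

Sensitivity mode then targets a desired value of the first rate, specificity mode the second.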
This repo reproduces the MNIST ablations that isolate and validate SAG, independent of the imitation-learning machinery in the main paper.
We advise using uv to install the dependencies of the askdagger_mnist package. If you don't have uv yet, please follow its installation instructions.
Clone the repository:

```shell
git clone https://github.com/askdagger/askdagger_mnist.git
cd askdagger_mnist
```

Create and activate the virtual environment:

```shell
uv venv --python 3.10
source .venv/bin/activate
```

Install the package:

```shell
uv pip install -e .
```

A short run on a CUDA-enabled GPU:

```shell
python ./scripts/main.py --reps 2 --s_des 0.9
```

Or on CPU:

```shell
python ./scripts/main.py --reps 2 --s_des 0.9 --accelerator cpu
```

Full main experiments:

```shell
python ./scripts/main.py
```

Ablations:

```shell
python ./scripts/ablations.py
```

At every step, a batch of `batch_size` MNIST images is sampled. The LeNet novice performs inference and its uncertainty is quantified per sample. SAG sets the gating threshold `p_rand`. All queried samples are added to the training set, and the model is updated every `update_every` steps.
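The per-step procedure can be sketched as follows. This is a simplified illustration, not the repository's implementation: the toy novice class and the fixed threshold are hypothetical, whereas SAG would adapt `p_rand` at every step.

```python
class ToyNovice:
    """Toy novice for illustration only: uncertainty is a stored
    per-sample score, and update just counts how often it is called."""
    def __init__(self, scores):
        self.scores = scores
        self.updates = 0

    def uncertainty(self, x):
        return self.scores[x]

    def update(self, train_set):
        self.updates += 1

def interactive_training(stream, novice, p_rand, update_every=2):
    """Simplified interactive loop: samples whose uncertainty exceeds
    the gating threshold p_rand are queried (teacher-labeled) and
    aggregated; the novice is updated every update_every steps.
    Here p_rand is fixed for simplicity; SAG would adjust it online."""
    train_set = []
    for step, batch in enumerate(stream):
        queried = [x for x in batch if novice.uncertainty(x) > p_rand]
        train_set.extend(queried)  # aggregate queried samples
        if (step + 1) % update_every == 0:
            novice.update(train_set)
    return train_set
```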
Uncertainty is estimated via Monte-Carlo dropout with a 40% dropout rate and 16 evaluations per input, giving an ensemble of 16 stochastic predictions per sample.
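To illustrate how such an ensemble yields a scalar uncertainty, the sketch below computes the predictive entropy of the averaged softmax outputs. It is a plain-Python stand-in for the repository's TorchUncertainty-based code, not an excerpt from it:

```python
import math

def predictive_entropy(passes):
    """passes: list of T softmax vectors (length C), one per stochastic
    forward pass with dropout kept active at inference time (T=16 here).
    Returns the entropy of the mean predictive distribution."""
    T = len(passes)
    C = len(passes[0])
    mean = [sum(p[c] for p in passes) / T for c in range(C)]
    # small epsilon guards against log(0) for classes with zero mass
    return -sum(m * math.log(m + 1e-12) for m in mean)
```

High entropy (near log C) indicates disagreement across the dropout passes, which SAG can then gate on.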
Instead of retraining, you can download the results used in the paper:

```shell
python scripts/download_results.py
```

After training or downloading the results, generate the main figure:

```shell
python ./scripts/plot.py
```

The resulting figure is saved as `figures/mnist.pdf`.
Generate the ablation plots:

```shell
python ./scripts/plot_reg_albation.py
python ./scripts/plot_prand_albation.py
```

If you find this work useful, please consider citing:
```bibtex
@article{luijkx2025askdagger,
  title   = {{ASkDAgger}: Active Skill-level Data Aggregation for Interactive Imitation Learning},
  author  = {Luijkx, Jelle and Ajanovi{\'c}, Zlatan and Ferranti, Laura and Kober, Jens},
  journal = {Transactions on Machine Learning Research (TMLR)},
  year    = {2025},
  url     = {https://openreview.net/forum?id=987Az9f8fT}
}
```

Our main training script is adapted from the classification tutorial of TorchUncertainty (repo, Apache 2.0). The data modules are modified to allow interactive training on a growing subset of the MNIST dataset.