
ASkDAgger (MNIST): S-Aware Gating on MNIST


This repository contains the code for the S-Aware Gating (SAG) experiments on MNIST from the paper:

ASkDAgger: Active Skill-level Data Aggregation for Interactive Imitation Learning
Jelle Luijkx, Zlatan Ajanović, Laura Ferranti, Jens Kober — TMLR 2025
[arXiv] [OpenReview] [Project page]

Related repository from the same project: askdagger_cliport for the language-conditioned manipulation experiments.

Overview

S-Aware Gating (SAG) dynamically adjusts the gating threshold $\gamma$ in an interactive-learning setting so the system tracks a user-specified target. Three modes are supported:

  • Sensitivity mode — guarantees a desired true-positive rate on queries. Use when system failures (false negatives) are costly.
  • Specificity mode — guarantees a desired true-negative rate. Use when unnecessary teacher queries (false positives) are costly.
  • Success mode — guarantees a minimum overall system success rate. SAG issues more queries when the novice underperforms and fewer once the target is comfortably met.

In this framing, queries are positives and autonomous actions are negatives: a false positive is an unnecessary query (lost autonomy), and a false negative is an uncaught invalid novice action (system failure).
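
As a concrete illustration of this framing (not code from the repository; the function name is hypothetical), each gating decision maps onto the confusion matrix as follows:

def classify_gating_outcome(queried: bool, novice_valid: bool) -> str:
    """Queries are positives; autonomous (novice) actions are negatives."""
    if queried:
        # A query is justified only if the novice action would have been invalid.
        return "TP (justified query)" if not novice_valid else "FP (unnecessary query, lost autonomy)"
    # Acting autonomously is correct only if the novice action is valid.
    return "TN (justified autonomy)" if novice_valid else "FN (uncaught failure)"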

This repo reproduces the MNIST ablations that isolate and validate SAG, independent of the imitation-learning machinery in the main paper.

Installation

Prerequisites: install uv

We advise using uv to install the dependencies of the askdagger_mnist package. If you don't have uv yet, install it by following its installation instructions.

Install askdagger_mnist

Clone the repository:

git clone https://github.com/askdagger/askdagger_mnist.git
cd askdagger_mnist

Create and activate the virtual environment:

uv venv --python 3.10
source .venv/bin/activate

Install the package:

uv pip install -e .

Training

Quick start

A short run on a CUDA-enabled GPU:

python ./scripts/main.py --reps 2 --s_des 0.9

Or on CPU:

python ./scripts/main.py --reps 2 --s_des 0.9 --accelerator cpu

Reproduce the paper

Full main experiments:

python ./scripts/main.py

Ablations:

python ./scripts/ablations.py

How interactive training works

At every step, a batch of batch_size MNIST images is sampled. The LeNet novice performs inference and its uncertainty is quantified per sample. SAG sets the gating threshold $\gamma$ from the target metric. For every sample with uncertainty above $\gamma$, a ground-truth label is queried; for samples below $\gamma$, a label is still queried with probability p_rand. All queried samples are added to the training set, and the model is updated every update_every steps.
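
A minimal sketch of this loop is below. It assumes hypothetical helpers sample_batch, query_label, and update, plus the mc_dropout_uncertainty estimator sketched under Uncertainty quantification; it is not the repository's implementation:

import random

def interactive_training(novice, pool, num_steps, batch_size, gamma_fn, p_rand, update_every):
    """Interactive training: uncertainty-gated querying plus periodic updates."""
    train_set = []
    for step in range(num_steps):
        batch = sample_batch(pool, batch_size)            # batch of MNIST images
        gamma = gamma_fn()                                # SAG sets the threshold
        for x in batch:
            u = mc_dropout_uncertainty(novice, x)         # per-sample uncertainty
            if u > gamma or random.random() < p_rand:     # gated or random query
                train_set.append((x, query_label(x)))     # aggregate queried labels
        if (step + 1) % update_every == 0:
            update(novice, train_set)                     # periodic model update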

Uncertainty quantification

Uncertainty is estimated via Monte-Carlo dropout with a 40% dropout rate and 16 evaluations per input, giving an ensemble $\mathcal{C} = \{h_1, \dots, h_C\}$ with $C = 16$. For an input $x$ and candidate label $y$:

$$u = 1 - \max_y P_\mathcal{C}(y \mid x), \quad P_\mathcal{C}(y \mid x) = \frac{1}{C} \sum_{i=1}^{C} P_i(y \mid x).$$
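
A PyTorch sketch of this estimator (assuming the novice outputs class logits and contains dropout layers; not the repository's exact code):

import torch

def mc_dropout_uncertainty(model: torch.nn.Module, x: torch.Tensor, n_samples: int = 16) -> torch.Tensor:
    """u = 1 - max_y P_C(y|x), with P_C averaged over n_samples stochastic passes."""
    model.train()  # keep dropout active at inference time
    with torch.no_grad():
        probs = torch.stack(
            [torch.softmax(model(x), dim=-1) for _ in range(n_samples)]
        ).mean(dim=0)                       # P_C(y|x): ensemble-averaged probabilities
    return 1.0 - probs.max(dim=-1).values  # per-sample uncertainty u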

Download results

Instead of retraining, you can download the results used in the paper:

python ./scripts/download_results.py

Reproduce the plots

After training or downloading the results, generate the main figure:

python ./scripts/plot.py

The resulting figure is saved as figures/mnist.pdf.

Generate the ablation plots:

python ./scripts/plot_reg_ablation.py
python ./scripts/plot_prand_ablation.py

Citation

If you find this work useful, please consider citing:

@article{luijkx2025askdagger,
  title   = {{ASkDAgger}: Active Skill-level Data Aggregation for Interactive Imitation Learning},
  author  = {Luijkx, Jelle and Ajanovi{\'c}, Zlatan and Ferranti, Laura and Kober, Jens},
  journal = {Transactions on Machine Learning Research (TMLR)},
  year    = {2025},
  url     = {https://openreview.net/forum?id=987Az9f8fT}
}

Acknowledgements

TorchUncertainty (Apache 2.0): our main training script is adapted from this classification tutorial. The data modules are modified to allow interactive training on a growing subset of the MNIST dataset.
