numpyro/NUTS fit_map path is significantly slower than emcee #34

@richteague

Summary

After the fit_map emcee speedups (f23cf29, 867bc38, 215a45e; cumulatively ~16×), the mcmc='numpyro' path is now slower in wall-clock time than mcmc='emcee' on typical fits, even though it is far more sample-efficient: it needs 256× fewer samples to reach the same posterior resolution on the validation fit in REFACTORING_PLAN.md §5.1b.

The two emcee speedups (JIT closure + vmap'd batch ln-prob) compile the per-step log-probability into a single XLA dispatch that vmaps across walkers. numpyro can't share that optimisation: NUTS extends its trajectory until a U-turn, which is a per-chain condition, so vmap'ing across chains doesn't help.
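The pattern described above (a JIT-compiled log-probability vmapped across walkers) can be sketched as follows. This is a minimal illustration, not the fit_map implementation: `log_prob` is a hypothetical stand-in (a unit Gaussian), and the walker and parameter counts are assumed values.

```python
import jax
import jax.numpy as jnp


def log_prob(theta):
    # Hypothetical stand-in for the real log-probability:
    # an isotropic unit Gaussian in the parameter vector.
    return -0.5 * jnp.sum(theta ** 2)


# vmap maps log_prob over the leading (walker) axis; jit compiles the
# batched function so one call evaluates every walker in one dispatch.
batched_log_prob = jax.jit(jax.vmap(log_prob))

n_walkers, n_dim = 64, 9  # assumed sizes, for illustration only
walkers = jax.random.normal(jax.random.PRNGKey(0), (n_walkers, n_dim))
lnp = batched_log_prob(walkers)  # one log-probability per walker
```

A batched function of this shape is what emcee's `EnsembleSampler(..., vectorize=True)` expects: it is then called with an array of walker positions rather than one position at a time.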

Profile of a current fit

9-parameter HD163296 3D fit (docs/tutorials/tutorial_6_numpyro.ipynb), 500 warmup + 500 samples, single chain:

  • Mean leapfrog steps per NUTS iteration: 127
  • Median: 95
  • ~39 % of iterations hit the max_tree_depth=8 cap (256 leapfrog steps)
  • Each gradient evaluation costs ~2× a likelihood evaluation

So one numpyro sample costs roughly 127 leapfrog steps × 2 ≈ 250 emcee likelihood evaluations.
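The arithmetic behind that estimate, using only the figures quoted in the profile, is spelled out below. The variable names are illustrative, and the cap follows the issue's figure of 256 leapfrog steps at depth 8.

```python
# Figures quoted in the profile above.
mean_leapfrog_steps = 127   # mean leapfrog steps per NUTS iteration
grad_cost_ratio = 2.0       # one gradient ~ 2 likelihood evaluations

# A tree of depth d caps the trajectory at about 2**d leapfrog steps.
cap_depth8 = 2 ** 8
cap_depth6 = 2 ** 6

# Cost of one NUTS sample, in units of one plain likelihood evaluation.
cost_per_sample = mean_leapfrog_steps * grad_cost_ratio

print(cap_depth8)        # 256
print(cap_depth6)        # 64
print(cost_per_sample)   # 254.0
```

Lowering max_tree_depth from 8 to 6 therefore cuts the worst-case trajectory by a factor of four, which matters because a large fraction of iterations hit the cap.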

Micro-benchmark (50 warmup + 50 samples, after JIT warm-up)

| Config | Wall time |
|---|---|
| 1 chain, max_tree_depth=8 (previous tutorial default) | 40.1 s |
| 1 chain, max_tree_depth=6 (new tutorial default) | 23.1 s |
| 4 chains, chain_method='sequential' | 197.3 s |
| 4 chains, chain_method='vectorized' | 194.3 s |

Findings

  • Capping max_tree_depth at 6 cuts wall time by ~42 % (40.1 s to 23.1 s, about 1.7× faster) with minimal posterior loss; tutorial 6 now recommends 6.
  • chain_method='vectorized' does not help (194.3 s vs 197.3 s sequential). NUTS' adaptive tree length is per-chain; vmap'd chains must all run to the longest tree on each iteration, cancelling the vmap win. (vectorized is only a win for HMC with fixed L.)
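The second finding can be illustrated with a small simulation. It is a sketch under a loud assumption: the tree-depth distribution below is made up (chosen only so that a sizeable fraction of iterations hit the depth-8 cap) and is not measured from the HD163296 fit.

```python
import random

random.seed(0)

# ASSUMPTION: a made-up distribution of tree depths per iteration.
DEPTHS = [4, 5, 6, 7, 8]
WEIGHTS = [0.10, 0.15, 0.20, 0.15, 0.40]


def draw_steps():
    """Leapfrog steps for one chain on one iteration (taken as 2**depth)."""
    return 2 ** random.choices(DEPTHS, weights=WEIGHTS)[0]


n_chains, n_iter = 4, 10_000
total_single = 0    # steps summed over chains run one at a time
total_longest = 0   # batched steps when all chains wait for the longest

for _ in range(n_iter):
    steps = [draw_steps() for _ in range(n_chains)]
    total_single += sum(steps)
    total_longest += max(steps)

mean_single = total_single / (n_iter * n_chains)  # per chain, per iteration
mean_longest = total_longest / n_iter             # batched, per iteration

print("mean steps, one chain alone:  ", mean_single)
print("mean steps, longest of four:  ", mean_longest)
```

Each batched step evaluates all four chains, so the vectorized run performs `n_chains * mean_longest` gradient evaluations per iteration against `n_chains * mean_single` for sequential chains. It only comes out ahead if the hardware runs a four-wide batch at close to the cost of a single evaluation.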

Untried directions

  • target_accept_prob below the 0.8 default: the step-size adaptation settles on a larger step, so each trajectory needs fewer leapfrog steps. Trade-off: noisier steps and more divergences.
  • dense_mass=True: adapt a full covariance preconditioner. Could help the strongly correlated z0/psi/r_taper/q_taper block, but the warmup adaptation is slower.
  • More aggressive image downsampling: _make_model cost scales with pixel count, and high-SNR pixels dominate the posterior.
  • Audit the numpyro JIT cache: confirm the model is compiled exactly once per fit and not invalidated between iterations.
  • See whether a bounded-horizon scan over leapfrog steps is feasible (probably blocked by the U-turn termination condition, but worth checking).

Workaround

Stay on mcmc='emcee' (the default) for routine fits. Use mcmc='numpyro' when sample efficiency is the actual bottleneck: a very expensive single likelihood evaluation, a very long autocorrelation time under emcee, or when a GPU is available.
