Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 22 additions & 8 deletions src/underworld3/systems/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3165,7 +3165,7 @@ def delta_t(self, value):
self._delta_t.sym = value

@timing.routine_timer_decorator
def estimate_dt(self, direction_aware: bool = False):
def estimate_dt(self, direction_aware: bool = False, percentile: float = 0.0):
r"""
Estimate an appropriate timestep for the advection-diffusion solver.

Expand Down Expand Up @@ -3275,12 +3275,28 @@ def estimate_dt(self, direction_aware: bool = False):
## dt_adv_i = h_i / |v_i| for advection
## dt_diff_i = h_i^2 / κ for diffusion (using global κ for now)

# Reduce per-element dt to one global value. Default (percentile=0) =
# strict global MINIMUM — one cell sets the limit. percentile>0 takes the
# Nth global percentile (50 = median) of the per-element dt instead, so a
# few anisotropic SLIVER cells (velocity ACROSS a thin cell) don't collapse
# dt. SLCN is unconditionally stable, and ``direction_aware`` already
# credits cells stretched ALONG the flow — together they give the
# orientation-aware + sliver-robust timestep.
def _reduce_dt(per_elem):
fin = per_elem[np.isfinite(per_elem)] if len(per_elem) else per_elem
if percentile and percentile > 0:
gathered = comm.allgather(np.ascontiguousarray(fin, dtype=float))
Comment on lines +3278 to +3288
allv = (np.concatenate([a for a in gathered if a.size])
if any(a.size for a in gathered) else np.empty(0))
return float(np.percentile(allv, percentile)) if allv.size else np.inf
loc = float(np.min(fin)) if len(fin) else np.inf
return comm.allreduce(loc, op=MPI.MIN)

# Per-element diffusive timestep (all elements use same diffusivity)
if diffusivity_glob > 0:
dt_diff_per_element = (element_radii ** 2) / diffusivity_glob
min_dt_diff_local = np.min(dt_diff_per_element) if len(dt_diff_per_element) > 0 else np.inf
else:
min_dt_diff_local = np.inf
dt_diff_per_element = np.array([np.inf])

# Per-element advective timestep — either isotropic
# (mesh._radii / |v|) or direction-aware (v-aligned cell
Expand Down Expand Up @@ -3318,11 +3334,9 @@ def estimate_dt(self, direction_aware: bool = False):
h_per_element / vel_magnitudes,
np.inf
)
min_dt_adv_local = np.min(dt_adv_per_element) if len(dt_adv_per_element) > 0 else np.inf

# Get global minimum timesteps (parallel-safe)
min_dt_diff_glob = comm.allreduce(min_dt_diff_local, op=MPI.MIN)
min_dt_adv_glob = comm.allreduce(min_dt_adv_local, op=MPI.MIN)
# Global reduction — strict min (percentile=0) or Nth percentile (median).
min_dt_diff_glob = _reduce_dt(dt_diff_per_element)
min_dt_adv_glob = _reduce_dt(dt_adv_per_element)

# Store for user inspection
self.dt_adv = min_dt_adv_glob if not np.isinf(min_dt_adv_glob) else 0.0
Expand Down
Loading