107 changes: 107 additions & 0 deletions src/underworld3/function/_function.pyx
@@ -368,6 +368,16 @@ def global_evaluate_nd( expr,
Users should typically use :func:`underworld3.function.global_evaluate`
which provides automatic unit handling and a cleaner interface.

Contract: this is a faithful *parallel* counterpart of :func:`evaluate` —
a query point is interpolated wherever in the mesh it lands (on any rank),
a point just outside the mesh is extrapolated from its true nearest cell,
and ``check_extrapolated`` returns an inside/outside flag per point. The
result is independent of the number of ranks (up to the rank-local
extrapolation residual near partition seams). Points that no rank can
locate in-cell are resolved by a best-claim reduction over ranks (see the
out-of-domain block below); set ``GE_LOCAL_FALLBACK=0`` to restore the
legacy behaviour where such points returned silently-wrong values.

Note it is not efficient to call this function to evaluate an expression at
a single coordinate. Instead the user should provide a numpy array of all
coordinates requiring evaluation.
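The `GE_LOCAL_FALLBACK` escape hatch mentioned in the docstring is read from the environment further down in this diff. As a minimal sketch of how that check behaves, the hypothetical helper below (`fallback_enabled` is not part of the PR) repeats the same membership test. It leaves out the `uw.mpi.size > 1` condition that the real code also requires.

```python
import os

# Values that switch the fallback off, copied from the check in the diff.
_DISABLED = ("0", "off", "false", "no", "")


def fallback_enabled(environ=os.environ):
    """Hypothetical helper: True unless GE_LOCAL_FALLBACK is a disabling value.

    An unset variable defaults to "1", so the fallback is on by default.
    The comparison is case-sensitive, exactly as written in the diff.
    """
    return environ.get("GE_LOCAL_FALLBACK", "1") not in _DISABLED


default_on = fallback_enabled({})
zero_off = fallback_enabled({"GE_LOCAL_FALLBACK": "0"})
empty_off = fallback_enabled({"GE_LOCAL_FALLBACK": ""})
upper_stays_on = fallback_enabled({"GE_LOCAL_FALLBACK": "OFF"})
```

Because the test is a plain string comparison, an upper-case value such as `OFF` is not in the disabling set and leaves the fallback enabled. Normalising with `.lower()` would be one way to change that, if it is wanted.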
@@ -520,6 +530,103 @@ def global_evaluate_nd( expr,
return_value[index, :, :] = data_container.array[:, :, :]
return_mask[index] = is_extrapolated.array[:]

# ------------------------------------------------------------------
# Out-of-domain extrapolation — keep the parallel result a faithful
# match for the serial ``evaluate()`` contract: interpolate a point
# wherever it lands across ranks, extrapolate a point just outside the
# mesh, and flag inside/outside.
#
# After the migrate round-trip, a query point that NO rank could locate
# in one of its cells returns flagged-extrapolated but valued from
# whichever rank the bare dm.migrate happened to strand it on — typically
# a geometrically far, WRONG cell (the classic symptom is an annulus
# boundary point reading a value from the opposite side of the domain).
# Serial ``evaluate()`` instead extrapolates from the TRUE nearest cell.
# Restore that contract with a "best-claim" reduction over the (small,
# boundary-layer) stranded set:
#
# 1. allgather the extrapolated points so every rank holds the SAME
# global set;
# 2. each rank reports, per point, its nearest-local-cell distance and
# its LOCAL rbf extrapolation of the field there;
# 3. Allreduce(MIN distance) + Allreduce(MIN rank) tie-break picks the
# rank whose nearest cell is globally closest, and Allreduce(SUM of
# the winner-only value/flag) scatters that rank's extrapolation back.
#
# A point some rank actually contains (distance ~ 0) naturally wins, so
# only genuinely-stranded points are corrected. Cost is O(boundary points)
# — no dense global tree, no exhaustive search.
#
# DEADLOCK SAFETY — read before editing. Every collective here (allgather,
# Allreduce) runs unconditionally on the IDENTICAL global set on every
# rank, so all ranks stay in lockstep (n_ext_total is itself a reduced
# value, so the `> 0` guard is taken identically everywhere). The per-rank
# value MUST come from the LOCAL rbf path (rbf=True): the FE interpolation
# path (petsc_interpolate / DMInterpolation) is itself collective and would
# desync here, because each rank classifies the same global set against its
# own domain (different interior-point counts) → hang. Never route the
# fallback value through FE interpolation.
#
# Serial is left untouched (the serial path above already extrapolates from
# the true nearest cell). Escape hatch: GE_LOCAL_FALLBACK=0 restores the
# legacy (silently-wrong out-of-domain) behaviour; default on.
# ------------------------------------------------------------------
import os
if uw.mpi.size > 1 and os.environ.get("GE_LOCAL_FALLBACK", "1") not in (
        "0", "off", "false", "no", ""):
    from mpi4py import MPI

    comm = uw.mpi.comm
    ext_idx = np.where(return_mask[:, 0, 0])[0]
    ext_coords = np.ascontiguousarray(coords_array[ext_idx], dtype=np.double)

    counts = np.array(comm.allgather(ext_coords.shape[0]), dtype=int)
    n_ext_total = int(counts.sum())

    if n_ext_total > 0:
        parts = comm.allgather(ext_coords)
        all_ext = np.concatenate(
            [p for p in parts if p.size], axis=0).reshape(n_ext_total, -1)

        # This rank's local rbf extrapolation of the global set. NON-collective
        # value path — see DEADLOCK SAFETY above (must be rbf=True, never FE).
        ext_vals, ext_flag = evaluate_nd(
            expr, all_ext, rbf=True, evalf=False, verbose=False,
            check_extrapolated=True,)
        ext_vals = np.ascontiguousarray(
            np.asarray(ext_vals, dtype=np.double).reshape((n_ext_total,) + expr_shape))
        ext_flag = np.asarray(ext_flag).reshape(n_ext_total).astype(np.int32)

        # Nearest-local-cell distance for every point (local kd-tree query).
        mesh._build_kd_tree_index()
        dist2, _ = mesh._centroid_index.query(all_ext, k=1, sqr_dists=True)
        dist2 = np.ascontiguousarray(np.asarray(dist2, dtype=np.double).ravel())

        # Globally-nearest cell per point, lowest rank as the tie-break.
        min_dist2 = np.empty(n_ext_total, dtype=np.double)
        comm.Allreduce([dist2, MPI.DOUBLE], [min_dist2, MPI.DOUBLE], op=MPI.MIN)
        my_claim = np.where(dist2 <= min_dist2 * (1.0 + 1e-12) + 1e-300,
                            comm.rank, comm.size).astype(np.int32)
        win_rank = np.empty(n_ext_total, dtype=np.int32)
        comm.Allreduce([my_claim, MPI.INT], [win_rank, MPI.INT], op=MPI.MIN)
        i_win = (win_rank == comm.rank)

        # Winner contributes value+flag, everyone else zero; SUM selects it.
        contrib_val = np.ascontiguousarray(
            np.where(i_win[:, None, None], ext_vals, 0.0))
        best_val = np.empty_like(contrib_val)
        comm.Allreduce([contrib_val, MPI.DOUBLE], [best_val, MPI.DOUBLE], op=MPI.SUM)
        contrib_flag = np.where(i_win, ext_flag, 0).astype(np.int32)
        best_flag = np.empty(n_ext_total, dtype=np.int32)
        comm.Allreduce([contrib_flag, MPI.INT], [best_flag, MPI.INT], op=MPI.SUM)

        # Scatter this rank's segment of the global set back to its points.
        offset = int(counts[:comm.rank].sum())
        seg = slice(offset, offset + ext_coords.shape[0])
        if ext_idx.size:
            return_value[ext_idx, :, :] = best_val[seg]
            return_mask[ext_idx, 0, 0] = best_flag[seg].astype(bool)

if not check_extrapolated:
    return return_value
else:
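The best-claim reduction described in the comment block (MIN over distances, MIN over claiming ranks, SUM of winner-only values) can be illustrated without MPI. The sketch below is an illustration only, not the PR's code: the hypothetical function `best_claim` stacks each rank's data along axis 0 and uses NumPy reductions over that axis as stand-ins for the `Allreduce` calls. The tolerance expression is copied from the diff, and the values are scalars per point for brevity.

```python
import numpy as np


def best_claim(dist2_by_rank, vals_by_rank):
    """Serial stand-in for the MIN / MIN / SUM Allreduce sequence.

    dist2_by_rank: (n_ranks, n_points) squared distance from each point to
                   the nearest local cell centroid on each rank.
    vals_by_rank:  (n_ranks, n_points) each rank's local extrapolated value.
    Returns the winning rank and the selected value for every point.
    """
    dist2 = np.asarray(dist2_by_rank, dtype=np.double)
    vals = np.asarray(vals_by_rank, dtype=np.double)
    n_ranks = dist2.shape[0]

    # Stand-in for Allreduce(MIN) over distances.
    min_dist2 = dist2.min(axis=0)

    # A rank claims a point if it is globally closest (within tolerance);
    # non-claimants report n_ranks, which can never win a MIN over ranks.
    ranks = np.arange(n_ranks)[:, None]
    claim = np.where(dist2 <= min_dist2 * (1.0 + 1e-12) + 1e-300, ranks, n_ranks)

    # Stand-in for Allreduce(MIN) over claims: lowest claiming rank wins ties.
    win_rank = claim.min(axis=0)

    # Winner contributes its value, everyone else zero; the sum selects it.
    contrib = np.where(ranks == win_rank, vals, 0.0)
    return win_rank, contrib.sum(axis=0)


# Three simulated ranks, three stranded points (made-up illustrative numbers).
dist2 = [[4.0, 1.0, 9.0],
         [1.0, 1.0, 16.0],
         [9.0, 25.0, 0.25]]
vals = [[10.0, 20.0, 30.0],
        [11.0, 21.0, 31.0],
        [12.0, 22.0, 32.0]]
win_rank, best_val = best_claim(dist2, vals)
# Point 1 is a tie between ranks 0 and 1, so rank 0 (the lowest) wins it.
```

Since exactly one rank wins each point, the SUM reduces to a selection. This is why the diff can move the winner's value and flag with an ordinary `Allreduce` rather than point-to-point messages.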