diff --git a/src/underworld3/function/_function.pyx b/src/underworld3/function/_function.pyx index 590dca34..0f7c2e1b 100644 --- a/src/underworld3/function/_function.pyx +++ b/src/underworld3/function/_function.pyx @@ -368,6 +368,16 @@ def global_evaluate_nd( expr, Users should typically use :func:`underworld3.function.global_evaluate` which provides automatic unit handling and a cleaner interface. + Contract: this is a faithful *parallel* counterpart of :func:`evaluate` — + a query point is interpolated wherever in the mesh it lands (on any rank), + a point just outside the mesh is extrapolated from its true nearest cell, + and ``check_extrapolated`` returns an inside/outside flag per point. The + result is independent of the number of ranks (up to the rank-local + extrapolation residual near partition seams). Points that no rank can + locate in-cell are resolved by a best-claim reduction over ranks (see the + out-of-domain block below); set ``GE_LOCAL_FALLBACK=0`` to restore the + legacy behaviour where such points returned silently-wrong values. + Note it is not efficient to call this function to evaluate an expression at a single coordinate. Instead the user should provide a numpy array of all coordinates requiring evaluation. @@ -520,6 +530,103 @@ def global_evaluate_nd( expr, return_value[index, :, :] = data_container.array[:, :, :] return_mask[index] = is_extrapolated.array[:] + # ------------------------------------------------------------------ + # Out-of-domain extrapolation — keep the parallel result a faithful + # match for the serial ``evaluate()`` contract: interpolate a point + # wherever it lands across ranks, extrapolate a point just outside the + # mesh, and flag inside/outside. 
+ # + # After the migrate round-trip, a query point that NO rank could locate + # in one of its cells returns flagged-extrapolated but valued from + # whichever rank the bare dm.migrate happened to strand it on — typically + # a geometrically far, WRONG cell (the classic symptom is an annulus + # boundary point reading a value from the opposite side of the domain). + # Serial ``evaluate()`` instead extrapolates from the TRUE nearest cell. + # Restore that contract with a "best-claim" reduction over the (small, + # boundary-layer) stranded set: + # + # 1. allgather the extrapolated points so every rank holds the SAME + # global set; + # 2. each rank reports, per point, its squared nearest-local-centroid distance and + # its LOCAL rbf extrapolation of the field there; + # 3. Allreduce(MIN distance) + Allreduce(MIN rank) tie-break picks the + # rank whose nearest cell centroid is globally closest, and Allreduce(SUM of + # the winner-only value/flag) scatters that rank's extrapolation back. + # + # A point some rank actually contains (small centroid distance) normally wins, so + # only genuinely-stranded points are corrected. Cost is O(boundary points) + # — no dense global tree, no exhaustive search. + # + # DEADLOCK SAFETY — read before editing. Every collective here (allgather, + # Allreduce) runs unconditionally on the IDENTICAL global set on every + # rank, so all ranks stay in lockstep (n_ext_total is the sum of allgathered + # counts, so the `> 0` guard is taken identically everywhere). The per-rank + # value MUST come from the LOCAL rbf path (rbf=True): the FE interpolation + # path (petsc_interpolate / DMInterpolation) is itself collective and would + # desync here, because each rank classifies the same global set against its + # own domain (different interior-point counts) → hang. Never route the + # fallback value through FE interpolation. + # + # Serial is left untouched (the serial path above already extrapolates from + # the true nearest cell). 
Escape hatch: GE_LOCAL_FALLBACK=0 restores the + # legacy (silently-wrong out-of-domain) behaviour; default on. + # ------------------------------------------------------------------ + import os + if uw.mpi.size > 1 and os.environ.get("GE_LOCAL_FALLBACK", "1") not in ( + "0", "off", "false", "no", ""): + from mpi4py import MPI + + comm = uw.mpi.comm + ext_idx = np.where(return_mask[:, 0, 0])[0] + ext_coords = np.ascontiguousarray(coords_array[ext_idx], dtype=np.double) + + counts = np.array(comm.allgather(ext_coords.shape[0]), dtype=int) + n_ext_total = int(counts.sum()) + + if n_ext_total > 0: + parts = comm.allgather(ext_coords) + all_ext = np.concatenate( + [p for p in parts if p.size], axis=0).reshape(n_ext_total, -1) + + # This rank's local rbf extrapolation of the global set. NON-collective + # value path — see DEADLOCK SAFETY above (must be rbf=True, never FE). + ext_vals, ext_flag = evaluate_nd( + expr, all_ext, rbf=True, evalf=False, verbose=False, + check_extrapolated=True,) + ext_vals = np.ascontiguousarray( + np.asarray(ext_vals, dtype=np.double).reshape((n_ext_total,) + expr_shape)) + ext_flag = np.asarray(ext_flag).reshape(n_ext_total).astype(np.int32) + + # Nearest-local-cell distance for every point (local kd-tree query). + mesh._build_kd_tree_index() + dist2, _ = mesh._centroid_index.query(all_ext, k=1, sqr_dists=True) + dist2 = np.ascontiguousarray(np.asarray(dist2, dtype=np.double).ravel()) + + # Globally-nearest cell per point, lowest rank as the tie-break. + min_dist2 = np.empty(n_ext_total, dtype=np.double) + comm.Allreduce([dist2, MPI.DOUBLE], [min_dist2, MPI.DOUBLE], op=MPI.MIN) + my_claim = np.where(dist2 <= min_dist2 * (1.0 + 1e-12) + 1e-300, + comm.rank, comm.size).astype(np.int32) + win_rank = np.empty(n_ext_total, dtype=np.int32) + comm.Allreduce([my_claim, MPI.INT], [win_rank, MPI.INT], op=MPI.MIN) + i_win = (win_rank == comm.rank) + + # Winner contributes value+flag, everyone else zero; SUM selects it. 
+ contrib_val = np.ascontiguousarray( + np.where(i_win[:, None, None], ext_vals, 0.0)) + best_val = np.empty_like(contrib_val) + comm.Allreduce([contrib_val, MPI.DOUBLE], [best_val, MPI.DOUBLE], op=MPI.SUM) + contrib_flag = np.where(i_win, ext_flag, 0).astype(np.int32) + best_flag = np.empty(n_ext_total, dtype=np.int32) + comm.Allreduce([contrib_flag, MPI.INT], [best_flag, MPI.INT], op=MPI.SUM) + + # Scatter this rank's segment of the global set back to its points. + offset = int(counts[:comm.rank].sum()) + seg = slice(offset, offset + ext_coords.shape[0]) + if ext_idx.size: + return_value[ext_idx, :, :] = best_val[seg] + return_mask[ext_idx, 0, 0] = best_flag[seg].astype(bool) + if not check_extrapolated: return return_value else: