Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 188 additions & 1 deletion sectionate/gridutils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,23 @@
import numpy as np
import xarray as xr


def get_facedim(grid):
"""
Return the name of the `grid`'s face/tile dimension if it has `face_connections`
metadata (multi-tile grids such as the lat-lon-cap or cubed-sphere), else None.

Parameters
----------
grid: xgcm.Grid

Returns
-------
str or None
"""
return getattr(grid, "_facedim", None)


def get_geo_corners(grid):
"""
Find longitude and latitude coordinates from grid dataset, assuming the coordinate
Expand Down Expand Up @@ -96,4 +116,171 @@ def check_symmetric(grid):
elif pos_dict["right"]:
return False
else:
raise ValueError("Horizontal grid axes ('X', 'Y') must be either both symmetric or both non-symmetric (by MOM6 conventions).")
raise ValueError("Horizontal grid axes ('X', 'Y') must be either both symmetric or both non-symmetric (by MOM6 conventions).")


# ---------------------------------------------------------------------------
# Topology-aware neighbor maps
#
# The section pathfinder walks corner-to-corner across the grid. Rather than
# hard-coding how to step across periodic boundaries or the connections between
# the tiles of a multi-tile grid, we precompute, for every corner point, the
# (face, j, i) index of each of its four neighbors ("right", "left", "up",
# "down"). Walls (no neighbor, e.g. a "fill"/"extend" edge) are represented as
# the point itself, so the pathfinder simply never finds them closer to the
# target -- reproducing the previous clip-to-edge behavior.
#
# - Single-tile grids: neighbors follow from each axis' `boundary` metadata
# ("periodic" -> wrap, otherwise clip). Computed with plain numpy on the
# (already-trimmed) coordinate array the pathfinder walks.
# - Multi-tile grids (`face_connections`): edge-crossings require xgcm's
# topology logic (axis rotation + reversal across face seams). We reuse it
# directly by padding index-valued arrays with `xgcm` and reading the halos.
# ---------------------------------------------------------------------------

NEIGHBOR_DIRECTIONS = ("right", "left", "up", "down")


def simple_neighbor_maps(shape, boundary):
"""
Neighbor maps for a single-tile grid, derived from `boundary` metadata.

Parameters
----------
shape: tuple of int
(ny, nx) shape of the corner coordinate array the pathfinder walks.
boundary: dict
Maps grid axis ("X", "Y") to its xgcm boundary condition. "periodic"
wraps; anything else clips to the edge (so edge points are their own
neighbor across that boundary).

Returns
-------
dict
Maps each of `NEIGHBOR_DIRECTIONS` to (fmap, jmap, imap), where fmap is
None (no face dimension) and jmap/imap are int arrays of shape (ny, nx)
giving the neighbor's (j, i) for every point.
"""
ny, nx = shape
I = np.broadcast_to(np.arange(nx), (ny, nx))
J = np.broadcast_to(np.arange(ny)[:, None], (ny, nx))

def step(idx, n, b, delta):
if b == "periodic":
return np.mod(idx + delta, n)
return np.clip(idx + delta, 0, n - 1)

bx, by = boundary.get("X"), boundary.get("Y")
return {
"right": (None, J, step(I, nx, bx, +1)),
"left": (None, J, step(I, nx, bx, -1)),
"up": (None, step(J, ny, by, +1), I),
"down": (None, step(J, ny, by, -1), I),
}


def build_neighbor_maps(grid, geocorners):
"""
Neighbor maps for a multi-tile grid, derived from `face_connections`.

Builds index-valued DataArrays holding each corner point's own (face, j, i),
pads them by one cell with `xgcm` (which fills the halos using the grid's
face connections, applying any axis rotation and reversal), then reads the
halos to obtain each point's four neighbors. This delegates the intricate
cross-face index translation to xgcm's tested padding logic.

Parameters
----------
grid: xgcm.Grid
A grid with `face_connections` metadata (multi-tile).
geocorners: dict
Output of `get_geo_corners`; `geocorners["X"]` is the corner-position
longitude DataArray with dims (facedim, Y, X).

Returns
-------
dict
Maps each of `NEIGHBOR_DIRECTIONS` to (fmap, jmap, imap), int arrays of
shape (nf, ny, nx). Walls are represented as the point itself.
"""
from xgcm.padding import pad as _module_pad

facedim = grid._facedim
da = geocorners["X"]
Ydim, Xdim = da.dims[-2], da.dims[-1]
nf, ny, nx = da.sizes[facedim], da.sizes[Ydim], da.sizes[Xdim]

dims = (facedim, Ydim, Xdim)
iarr = xr.DataArray(np.broadcast_to(np.arange(nx), (nf, ny, nx)).astype(float), dims=dims)
jarr = xr.DataArray(np.broadcast_to(np.arange(ny)[:, None], (nf, ny, nx)).astype(float), dims=dims)
farr = xr.DataArray(np.broadcast_to(np.arange(nf)[:, None, None], (nf, ny, nx)).astype(float), dims=dims)

boundary = {ax: grid.axes[ax].boundary for ax in grid.axes}
boundary_width = {ax: (1, 1) for ax in grid.axes}

def pad(a):
# Prefer the public Grid.pad (xgcm >= 0.9); fall back to the module-level
# function in older releases that lack the method.
if hasattr(grid, "pad"):
return grid.pad(a, boundary_width=boundary_width, boundary=boundary, fill_value=np.nan)
return _module_pad(a, grid, boundary_width, boundary=boundary, fill_value=np.nan)

pf, pj, pi = pad(farr), pad(jarr), pad(iarr)

interior = slice(1, -1)
slices = {
"right": (interior, slice(2, None)),
"left": (interior, slice(0, -2)),
"up": (slice(2, None), interior),
"down": (slice(0, -2), interior),
}

own_f = np.broadcast_to(np.arange(nf)[:, None, None], (nf, ny, nx))
own_j = np.broadcast_to(np.arange(ny)[:, None], (nf, ny, nx))
own_i = np.broadcast_to(np.arange(nx), (nf, ny, nx))

maps = {}
for d, (ysl, xsl) in slices.items():
fmap = pf.isel({Ydim: ysl, Xdim: xsl}).values
jmap = pj.isel({Ydim: ysl, Xdim: xsl}).values
imap = pi.isel({Ydim: ysl, Xdim: xsl}).values
# A NaN halo means "no neighbor" (an unconnected/fill edge) -- represent
# the wall as the point itself, matching the single-tile clip behavior.
wall = np.isnan(fmap) | np.isnan(jmap) | np.isnan(imap)
fmap = np.where(wall, own_f, fmap).astype(np.int64)
jmap = np.where(wall, own_j, jmap).astype(np.int64)
imap = np.where(wall, own_i, imap).astype(np.int64)
maps[d] = (fmap, jmap, imap)

_validate_reciprocity(maps, own_f, own_j, own_i)
return maps


def _validate_reciprocity(maps, own_f, own_j, own_i):
"""
Verify that the neighbor maps describe a consistent topology: if B is a
(non-wall) neighbor of A, then A must be one of B's four neighbors.

xgcm's padding can mis-fill halos for complex reversed/rotated face
connections (its behavior is even hash-seed dependent for some
configurations, e.g. a full cubed sphere). Such a failure would otherwise
yield silently-wrong sections, so we detect it here and refuse rather than
return garbage neighbors.
"""
for d, (fmap, jmap, imap) in maps.items():
not_wall = ~((fmap == own_f) & (jmap == own_j) & (imap == own_i))
# Gather each neighbor's own four neighbors and look for the point back.
reciprocated = np.zeros(fmap.shape, dtype=bool)
for (f2, j2, i2) in maps.values():
bf = f2[fmap, jmap, imap]
bj = j2[fmap, jmap, imap]
bi = i2[fmap, jmap, imap]
reciprocated |= (bf == own_f) & (bj == own_j) & (bi == own_i)
if np.any(not_wall & ~reciprocated):
raise NotImplementedError(
"Could not derive a consistent neighbor topology from this grid's "
"`face_connections`. This is a known limitation of xgcm's padding for "
"complex rotated/reversed connections (e.g. a full cubed sphere or the "
"lat-lon-cap arctic cap). Sections on grids with simpler face connections "
"are supported."
)
Loading
Loading