diff --git a/sectionate/gridutils.py b/sectionate/gridutils.py index 19cccfc..2a6d376 100644 --- a/sectionate/gridutils.py +++ b/sectionate/gridutils.py @@ -1,3 +1,23 @@ +import numpy as np +import xarray as xr + + +def get_facedim(grid): + """ + Return the name of the `grid`'s face/tile dimension if it has `face_connections` + metadata (multi-tile grids such as the lat-lon-cap or cubed-sphere), else None. + + Parameters + ---------- + grid: xgcm.Grid + + Returns + ------- + str or None + """ + return getattr(grid, "_facedim", None) + + def get_geo_corners(grid): """ Find longitude and latitude coordinates from grid dataset, assuming the coordinate @@ -96,4 +116,171 @@ def check_symmetric(grid): elif pos_dict["right"]: return False else: - raise ValueError("Horizontal grid axes ('X', 'Y') must be either both symmetric or both non-symmetric (by MOM6 conventions).") \ No newline at end of file + raise ValueError("Horizontal grid axes ('X', 'Y') must be either both symmetric or both non-symmetric (by MOM6 conventions).") + + +# --------------------------------------------------------------------------- +# Topology-aware neighbor maps +# +# The section pathfinder walks corner-to-corner across the grid. Rather than +# hard-coding how to step across periodic boundaries or the connections between +# the tiles of a multi-tile grid, we precompute, for every corner point, the +# (face, j, i) index of each of its four neighbors ("right", "left", "up", +# "down"). Walls (no neighbor, e.g. a "fill"/"extend" edge) are represented as +# the point itself, so the pathfinder simply never finds them closer to the +# target -- reproducing the previous clip-to-edge behavior. +# +# - Single-tile grids: neighbors follow from each axis' `boundary` metadata +# ("periodic" -> wrap, otherwise clip). Computed with plain numpy on the +# (already-trimmed) coordinate array the pathfinder walks. +# - Multi-tile grids (`face_connections`): edge-crossings require xgcm's +# topology logic (axis rotation + reversal across face seams). We reuse it +# directly by padding index-valued arrays with `xgcm` and reading the halos. +# --------------------------------------------------------------------------- + +NEIGHBOR_DIRECTIONS = ("right", "left", "up", "down") + + +def simple_neighbor_maps(shape, boundary): + """ + Neighbor maps for a single-tile grid, derived from `boundary` metadata. + + Parameters + ---------- + shape: tuple of int + (ny, nx) shape of the corner coordinate array the pathfinder walks. + boundary: dict + Maps grid axis ("X", "Y") to its xgcm boundary condition. "periodic" + wraps; anything else clips to the edge (so edge points are their own + neighbor across that boundary). + + Returns + ------- + dict + Maps each of `NEIGHBOR_DIRECTIONS` to (fmap, jmap, imap), where fmap is + None (no face dimension) and jmap/imap are int arrays of shape (ny, nx) + giving the neighbor's (j, i) for every point. + """ + ny, nx = shape + I = np.broadcast_to(np.arange(nx), (ny, nx)) + J = np.broadcast_to(np.arange(ny)[:, None], (ny, nx)) + + def step(idx, n, b, delta): + if b == "periodic": + return np.mod(idx + delta, n) + return np.clip(idx + delta, 0, n - 1) + + bx, by = boundary.get("X"), boundary.get("Y") + return { + "right": (None, J, step(I, nx, bx, +1)), + "left": (None, J, step(I, nx, bx, -1)), + "up": (None, step(J, ny, by, +1), I), + "down": (None, step(J, ny, by, -1), I), + } + + +def build_neighbor_maps(grid, geocorners): + """ + Neighbor maps for a multi-tile grid, derived from `face_connections`. + + Builds index-valued DataArrays holding each corner point's own (face, j, i), + pads them by one cell with `xgcm` (which fills the halos using the grid's + face connections, applying any axis rotation and reversal), then reads the + halos to obtain each point's four neighbors. This delegates the intricate + cross-face index translation to xgcm's tested padding logic. + + Parameters + ---------- + grid: xgcm.Grid + A grid with `face_connections` metadata (multi-tile). + geocorners: dict + Output of `get_geo_corners`; `geocorners["X"]` is the corner-position + longitude DataArray with dims (facedim, Y, X). + + Returns + ------- + dict + Maps each of `NEIGHBOR_DIRECTIONS` to (fmap, jmap, imap), int arrays of + shape (nf, ny, nx). Walls are represented as the point itself. + """ + from xgcm.padding import pad as _module_pad + + facedim = grid._facedim + da = geocorners["X"] + Ydim, Xdim = da.dims[-2], da.dims[-1] + nf, ny, nx = da.sizes[facedim], da.sizes[Ydim], da.sizes[Xdim] + + dims = (facedim, Ydim, Xdim) + iarr = xr.DataArray(np.broadcast_to(np.arange(nx), (nf, ny, nx)).astype(float), dims=dims) + jarr = xr.DataArray(np.broadcast_to(np.arange(ny)[:, None], (nf, ny, nx)).astype(float), dims=dims) + farr = xr.DataArray(np.broadcast_to(np.arange(nf)[:, None, None], (nf, ny, nx)).astype(float), dims=dims) + + boundary = {ax: grid.axes[ax].boundary for ax in grid.axes} + boundary_width = {ax: (1, 1) for ax in grid.axes} + + def pad(a): + # Prefer the public Grid.pad (xgcm >= 0.9); fall back to the module-level + # function in older releases that lack the method. + if hasattr(grid, "pad"): + return grid.pad(a, boundary_width=boundary_width, boundary=boundary, fill_value=np.nan) + return _module_pad(a, grid, boundary_width, boundary=boundary, fill_value=np.nan) + + pf, pj, pi = pad(farr), pad(jarr), pad(iarr) + + interior = slice(1, -1) + slices = { + "right": (interior, slice(2, None)), + "left": (interior, slice(0, -2)), + "up": (slice(2, None), interior), + "down": (slice(0, -2), interior), + } + + own_f = np.broadcast_to(np.arange(nf)[:, None, None], (nf, ny, nx)) + own_j = np.broadcast_to(np.arange(ny)[:, None], (nf, ny, nx)) + own_i = np.broadcast_to(np.arange(nx), (nf, ny, nx)) + + maps = {} + for d, (ysl, xsl) in slices.items(): + fmap = pf.isel({Ydim: ysl, Xdim: xsl}).values + jmap = pj.isel({Ydim: ysl, Xdim: xsl}).values + imap = pi.isel({Ydim: ysl, Xdim: xsl}).values + # A NaN halo means "no neighbor" (an unconnected/fill edge) -- represent + # the wall as the point itself, matching the single-tile clip behavior. + wall = np.isnan(fmap) | np.isnan(jmap) | np.isnan(imap) + fmap = np.where(wall, own_f, fmap).astype(np.int64) + jmap = np.where(wall, own_j, jmap).astype(np.int64) + imap = np.where(wall, own_i, imap).astype(np.int64) + maps[d] = (fmap, jmap, imap) + + _validate_reciprocity(maps, own_f, own_j, own_i) + return maps + + +def _validate_reciprocity(maps, own_f, own_j, own_i): + """ + Verify that the neighbor maps describe a consistent topology: if B is a + (non-wall) neighbor of A, then A must be one of B's four neighbors. + + xgcm's padding can mis-fill halos for complex reversed/rotated face + connections (its behavior is even hash-seed dependent for some + configurations, e.g. a full cubed sphere). Such a failure would otherwise + yield silently-wrong sections, so we detect it here and refuse rather than + return garbage neighbors. + """ + for d, (fmap, jmap, imap) in maps.items(): + not_wall = ~((fmap == own_f) & (jmap == own_j) & (imap == own_i)) + # Gather each neighbor's own four neighbors and look for the point back. + reciprocated = np.zeros(fmap.shape, dtype=bool) + for (f2, j2, i2) in maps.values(): + bf = f2[fmap, jmap, imap] + bj = j2[fmap, jmap, imap] + bi = i2[fmap, jmap, imap] + reciprocated |= (bf == own_f) & (bj == own_j) & (bi == own_i) + if np.any(not_wall & ~reciprocated): + raise NotImplementedError( + "Could not derive a consistent neighbor topology from this grid's " + "`face_connections`. This is a known limitation of xgcm's padding for " + "complex rotated/reversed connections (e.g. a full cubed sphere or the " + "lat-lon-cap arctic cap). Sections on grids with simpler face connections " + "are supported." + ) \ No newline at end of file diff --git a/sectionate/section.py b/sectionate/section.py index df1433d..599511b 100644 --- a/sectionate/section.py +++ b/sectionate/section.py @@ -1,7 +1,13 @@ import numpy as np import xarray as xr -from .gridutils import get_geo_corners, check_symmetric +from .gridutils import ( + get_geo_corners, + check_symmetric, + get_facedim, + build_neighbor_maps, + simple_neighbor_maps, +) class Section(): """A named hydrographic section""" @@ -113,7 +119,7 @@ class GriddedSection(Section): ------- instance of GriddedSection """ - def __init__(self, section, grid, i_c=None, j_c=None): + def __init__(self, section, grid, i_c=None, j_c=None, f_c=None): super().__init__( section.name, section.coords, @@ -121,6 +127,7 @@ def __init__(self, section, grid, i_c=None, j_c=None): parent = section.parent ) self.grid = grid + self.f_c = f_c if isinstance(i_c, (list, np.ndarray)) & isinstance(j_c, (list, np.ndarray)): self.i_c = i_c self.j_c = j_c @@ -138,14 +145,19 @@ def grid_section(self, **kwargs): ----------------- **kwargs passed directly to sectionate.grid_section """ - self.i_c, self.j_c, self.lons_c, self.lats_c = grid_section( + out = grid_section( self.grid, self.lons_c, self.lats_c, **kwargs ) - - return self.i_c, self.j_c, self.lons_c, self.lats_c + if len(out) == 5: + self.i_c, self.j_c, self.f_c, self.lons_c, self.lats_c = out + else: + self.i_c, self.j_c, self.lons_c, self.lats_c = out + self.f_c = None + + return out def copy(self): """Creates a copy of a GriddedSection, with deep copies of all attributes except the grid.""" @@ -212,40 +224,77 @@ def join_sections(name, *sections, **kwargs): return section -def grid_section(grid, lons, lats, topology="latlon"): +def grid_section(grid, lons, lats): """ Compute composite section along model `grid` velocity faces that approximates geodesic paths between consecutive points defined by (lons, lats). + The grid topology is inferred entirely from the `grid` metadata: each axis' `boundary` + condition ("periodic" wraps, otherwise clip) for single-tile grids, and `face_connections` + for multi-tile grids (e.g. the lat-lon-cap or cubed-sphere). + Parameters ---------- grid: xgcm.Grid Object describing the geometry of the ocean model grid, including metadata about variable names for - the staggered C-grid dimensions and c oordinates. + the staggered C-grid dimensions and coordinates. lons: list or np.ndarray Longitudes, in degrees, of consecutive vertices defining a piece-wise geodesic section. lats: list or np.ndarray Latitudes, in degrees (in range [-90, 90]), of consecutive vertices defining a piece-wise geodesic section. - topology: str - Default: "latlon". Currently only supports the following options: ["latlon", "cartesian", "MOM-tripolar"]. - + Returns ------- - i_c, j_c, lons_c, lats_c: `np.ndarray` of types (int, int, float, float) - (i_c, j_c) correspond to indices of vorticity points that define velocity faces. + i_c, j_c[, f_c], lons_c, lats_c: `np.ndarray` + (i_c, j_c) correspond to indices of vorticity points that define velocity faces. For + multi-tile grids, the face index f_c of each point is returned as well. (lons_c, lats_c) are the corresponding longitude and latitudes. """ geocorners = get_geo_corners(grid) + boundary = {ax:grid.axes[ax].boundary for ax in grid.axes} + + facedim = get_facedim(grid) + if facedim is not None: + _check_supported_topology(grid) + neighbor_maps = build_neighbor_maps(grid, geocorners) + else: + neighbor_maps = None + return create_section_composite( geocorners["X"], geocorners["Y"], lons, lats, check_symmetric(grid), - boundary={ax:grid.axes[ax]._boundary for ax in grid.axes}, - topology=topology + boundary=boundary, + neighbor_maps=neighbor_maps, ) + +def _check_supported_topology(grid): + """ + Raise if the multi-tile `grid` requires topology features sectionate does not yet support. + + Specifically, the tripolar/bipolar north fold is represented as a face that connects to + itself (a self-connection) along the "Y" axis. xgcm does not support this either (see + https://github.com/xgcm/xgcm/issues/194), and sectionate's pathfinder cannot currently + cross such a seam, so we refuse it explicitly rather than silently produce a wrong path. + """ + facedim = grid._facedim + for axis in grid.axes: + connections = getattr(grid.axes[axis], "_connections", None) or {} + for face, sides in connections.items(): + for side in sides: + if side is None: + continue + neighbor_face = side[0] + if neighbor_face == face: + raise NotImplementedError( + "Grids with a face that connects to itself (e.g. the tripolar/bipolar " + "north fold) are not yet supported. Sections that do not cross the fold " + "work with a single-tile grid using boundary={'X':'periodic','Y':'extend'}." + ) + def create_section_composite( gridlon, gridlat, @@ -253,7 +302,7 @@ def create_section_composite( lats, symmetric, boundary={"X":"periodic", "Y":"extend"}, - topology="latlon" + neighbor_maps=None ): """ Compute composite section along velocity faces, as defined by coordinates of vorticity points (gridlon, gridlat), @@ -263,9 +312,10 @@ def create_section_composite( ----------- gridlon: np.ndarray - 2d array of longitude (with dimensions ("Y", "X")), in degrees + Array of longitude, in degrees. 2d (Y, X) for single-tile grids; 3d (face, Y, X) for + multi-tile grids (`face_connections`). gridlat: np.ndarray - 2d array of latitude (with dimensions ("Y", "X")), in degrees + Array of latitude, in degrees, with the same shape as `gridlon`. lons: list of float longitude of section starting, intermediate and end points, in degrees lats: list of float @@ -274,19 +324,24 @@ def create_section_composite( True if symmetric (vorticity on "outer" positions); False if non-symmetric (assuming "right" positions). boundary: dictionary mapping grid axis to boundary condition Default: {"X":"periodic", "Y":"extend"}. Set to {"X":"extend", "Y":"extend"} if using a non-periodic regional domain. - topology: str - Default: "latlon". Currently only supports the following options: ["latlon", "cartesian", "MOM-tripolar"]. + neighbor_maps: dict or None + Precomputed topology-aware neighbor maps for multi-tile grids (see + `sectionate.gridutils.build_neighbor_maps`); None for single-tile grids. RETURNS: ------- - i_c, j_c, lons_c, lats_c: `np.ndarray` of types (int, int, float, float) - (i_c, j_c) correspond to indices of vorticity points that define velocity faces. + i_c, j_c[, f_c], lons_c, lats_c: `np.ndarray` + (i_c, j_c[, f_c]) correspond to indices of vorticity points that define velocity faces; + the face index f_c is only returned for multi-tile grids. (lons_c, lats_c) are the corresponding longitude and latitudes. """ + multitile = neighbor_maps is not None + i_c = np.array([], dtype=np.int64) j_c = np.array([], dtype=np.int64) + f_c = np.array([], dtype=np.int64) lons_c = np.array([], dtype=np.float64) lats_c = np.array([], dtype=np.float64) @@ -294,7 +349,7 @@ def create_section_composite( raise ValueError("lons and lats should have the same length") for k in range(len(lons) - 1): - i_c_seg, j_c_seg, lons_c_seg, lats_c_seg = create_section( + seg = create_section( gridlon, gridlat, lons[k], @@ -303,22 +358,31 @@ def create_section_composite( lats[k + 1], symmetric, boundary=boundary, - topology=topology + neighbor_maps=neighbor_maps, ) + if multitile: + i_c_seg, j_c_seg, f_c_seg, lons_c_seg, lats_c_seg = seg + else: + i_c_seg, j_c_seg, lons_c_seg, lats_c_seg = seg i_c = np.concatenate([i_c, i_c_seg[:-1]], axis=0) j_c = np.concatenate([j_c, j_c_seg[:-1]], axis=0) lons_c = np.concatenate([lons_c, lons_c_seg[:-1]], axis=0) lats_c = np.concatenate([lats_c, lats_c_seg[:-1]], axis=0) - + if multitile: + f_c = np.concatenate([f_c, f_c_seg[:-1]], axis=0) + i_c = np.concatenate([i_c, [i_c_seg[-1]]], axis=0) j_c = np.concatenate([j_c, [j_c_seg[-1]]], axis=0) lons_c = np.concatenate([lons_c, [lons_c_seg[-1]]], axis=0) lats_c = np.concatenate([lats_c, [lats_c_seg[-1]]], axis=0) + if multitile: + f_c = np.concatenate([f_c, [f_c_seg[-1]]], axis=0) + return i_c.astype(np.int64), j_c.astype(np.int64), f_c.astype(np.int64), lons_c, lats_c return i_c.astype(np.int64), j_c.astype(np.int64), lons_c, lats_c -def create_section(gridlon, gridlat, lonstart, latstart, lonend, latend, symmetric, boundary={"X":"periodic", "Y":"extend"}, topology="latlon"): +def create_section(gridlon, gridlat, lonstart, latstart, lonend, latend, symmetric, boundary={"X":"periodic", "Y":"extend"}, neighbor_maps=None): """ Compute a section segment along velocity faces, as defined by coordinates of vorticity points (gridlon, gridlat), that most closely approximates the geodesic path between points (lonstart, latstart) and (lonend, latend). @@ -342,22 +406,27 @@ def create_section(gridlon, gridlat, lonstart, latstart, lonend, latend, symmetr True if symmetric (vorticity on "outer" positions); False if non-symmetric (assuming "right" positions). boundary: dictionary mapping grid axis to boundary condition Default: {"X":"periodic", "Y":"extend"}. Set to {"X":"extend", "Y":"extend"} if using a non-periodic regional domain. - topology: str - Default: "latlon". Currently only supports the following options: ["latlon", "cartesian", "MOM-tripolar"]. + neighbor_maps: dict or None + Precomputed topology-aware neighbor maps for multi-tile grids (see + `sectionate.gridutils.build_neighbor_maps`); None for single-tile grids. RETURNS: ------- - i_c, j_c, lons_c, lats_c: `np.ndarray` of types (int, int, float, float) - (i_c, j_c) correspond to indices of vorticity points that define velocity faces. + i_c, j_c[, f_c], lons_c, lats_c: `np.ndarray` + (i_c, j_c[, f_c]) correspond to indices of vorticity points that define velocity faces; + the face index f_c is only returned for multi-tile grids. (lons_c, lats_c) are the corresponding longitude and latitudes. """ - if symmetric and boundary["X"] == "periodic": + # Symmetric periodic single-tile grids carry a redundant final corner column + # (the periodic wrap of the first); drop it so periodicity is expressed purely + # by the modulo step. Multi-tile periodicity is handled by `face_connections`. + if symmetric and boundary["X"] == "periodic" and neighbor_maps is None: gridlon=gridlon[:,:-1] gridlat=gridlat[:,:-1] - i_c_seg, j_c_seg, lons_c_seg, lats_c_seg = infer_grid_path_from_geo( + return infer_grid_path_from_geo( lonstart, latstart, lonend, @@ -365,16 +434,10 @@ def create_section(gridlon, gridlat, lonstart, latstart, lonend, latend, symmetr gridlon, gridlat, boundary=boundary, - topology=topology - ) - return ( - i_c_seg, - j_c_seg, - lons_c_seg, - lats_c_seg + neighbor_maps=neighbor_maps, ) -def infer_grid_path_from_geo(lonstart, latstart, lonend, latend, gridlon, gridlat, boundary={"X":"periodic", "Y":"extend"}, topology="latlon"): +def infer_grid_path_from_geo(lonstart, latstart, lonend, latend, gridlon, gridlat, boundary={"X":"periodic", "Y":"extend"}, neighbor_maps=None): """ Find the grid indices (and coordinates) of vorticity points that most closely approximates the geodesic path between points (lonstart, latstart) and (lonend, latend). @@ -391,35 +454,35 @@ def infer_grid_path_from_geo(lonstart, latstart, lonend, latend, gridlon, gridla latend: float latitude of section end point, in degrees gridlon: np.ndarray - 2d array of longitude, in degrees + Array of longitude, in degrees. 2d (Y, X) for single-tile grids; 3d (face, Y, X) for + multi-tile grids (`face_connections`). gridlat: np.ndarray - 2d array of latitude, in degrees + Array of latitude, in degrees, with the same shape as `gridlon`. boundary: dictionary mapping grid axis to boundary condition Default: {"X":"periodic", "Y":"extend"}. Set to {"X":"extend", "Y":"extend"} if using a non-periodic regional domain. - topology: str - Default: "latlon". Currently only supports the following options: ["latlon", "cartesian", "MOM-tripolar"]. + neighbor_maps: dict or None + Precomputed topology-aware neighbor maps for multi-tile grids (see + `sectionate.gridutils.build_neighbor_maps`); None for single-tile grids. RETURNS: ------- - i_c, j_c, lons_c, lats_c: `np.ndarray` of types (int, int, float, float) - (i_c, j_c) correspond to indices of vorticity points that define velocity faces. + i_c, j_c[, f_c], lons_c, lats_c: `np.ndarray` + (i_c, j_c[, f_c]) correspond to indices of vorticity points that define velocity faces; + the face index f_c is only returned for multi-tile grids. (lons_c, lats_c) are the corresponding longitude and latitudes. """ - istart, jstart = find_closest_grid_point( - lonstart, - latstart, - gridlon, - gridlat - ) - iend, jend = find_closest_grid_point( - lonend, - latend, - gridlon, - gridlat - ) - i_c_seg, j_c_seg, lons_c_seg, lats_c_seg = infer_grid_path( + multitile = neighbor_maps is not None + if multitile: + istart, jstart, fstart = find_closest_grid_point(lonstart, latstart, gridlon, gridlat) + iend, jend, fend = find_closest_grid_point(lonend, latend, gridlon, gridlat) + else: + istart, jstart = find_closest_grid_point(lonstart, latstart, gridlon, gridlat) + iend, jend = find_closest_grid_point(lonend, latend, gridlon, gridlat) + fstart, fend = None, None + + return infer_grid_path( istart, jstart, iend, @@ -427,17 +490,16 @@ def infer_grid_path_from_geo(lonstart, latstart, lonend, latend, gridlon, gridla gridlon, gridlat, boundary=boundary, - topology=topology + neighbor_maps=neighbor_maps, + f1=fstart, + f2=fend, ) - return i_c_seg, j_c_seg, lons_c_seg, lats_c_seg - -def infer_grid_path(i1, j1, i2, j2, gridlon, gridlat, boundary={"X":"periodic", "Y":"extend"}, topology="latlon"): +def infer_grid_path(i1, j1, i2, j2, gridlon, gridlat, boundary={"X":"periodic", "Y":"extend"}, neighbor_maps=None, f1=None, f2=None): """ Find the grid indices (and coordinates) of vorticity points that most closely approximate - the geodesic path between points (gridlon[j1,i1], gridlat[j1,i1]) and - (gridlon[j2,i2], gridlat[j2,i2]). + the geodesic path between the starting and ending corner points. PARAMETERS: ----------- @@ -451,100 +513,114 @@ def infer_grid_path(i1, j1, i2, j2, gridlon, gridlat, boundary={"X":"periodic", j2: integer j-coord of point2 gridlon: np.ndarray - 2d array of longitude, in degrees + Array of longitude, in degrees. 2d (Y, X) for single-tile grids; 3d (face, Y, X) for + multi-tile grids (`face_connections`). gridlat: np.ndarray - 2d array of latitude, in degrees + Array of latitude, in degrees, with the same shape as `gridlon`. boundary: dictionary mapping grid axis to boundary condition Default: {"X":"periodic", "Y":"extend"}. Set to {"X":"extend", "Y":"extend"} if using a non-periodic regional domain. - topology: str - Default: "latlon". Currently only supports the following options: ["latlon", "cartesian", "MOM-tripolar"]. + Only used for single-tile grids (when `neighbor_maps is None`). + neighbor_maps: dict or None + Precomputed topology-aware neighbor maps (see `sectionate.gridutils.build_neighbor_maps`). + Required for multi-tile grids. If None, single-tile maps are built from `boundary`. + f1, f2: integer or None + Face indices of the starting and ending points (multi-tile grids only); None otherwise. RETURNS: ------- - i_c_seg, j_c_seg: list of int - list of (i,j) pairs bounded by (i1, j1) and (i2, j2) - lons_c_seg, lats_c_seg: list of float - corresponding longitude and latitude for i_c_seg, j_c_seg + For single-tile grids: + i_c_seg, j_c_seg, lons_c_seg, lats_c_seg + For multi-tile grids, additionally the face index of each point: + i_c_seg, j_c_seg, f_c_seg, lons_c_seg, lats_c_seg + + (i_c_seg, j_c_seg[, f_c_seg]) are the vorticity-point indices bounded by the start and end + points; (lons_c_seg, lats_c_seg) are the corresponding longitude and latitude. """ - ny, nx = gridlon.shape - if isinstance(gridlon, xr.core.dataarray.DataArray): gridlon = gridlon.values if isinstance(gridlat, xr.core.dataarray.DataArray): gridlat = gridlat.values + multitile = neighbor_maps is not None + if multitile: + nfaces, ny, nx = gridlon.shape + else: + # Single-tile grid: neighbors follow directly from `boundary` metadata. + # Build the lookup maps from the (already-trimmed) coordinate array so + # they stay consistent with the array we actually walk. + ny, nx = gridlon.shape + nfaces = 1 + neighbor_maps = simple_neighbor_maps((ny, nx), boundary) + + def coord(arr, f, j, i): + return arr[j, i] if f is None else arr[f, j, i] + + def neighbor(direction, f, j, i): + fmap, jmap, imap = neighbor_maps[direction] + if fmap is None: + return (None, int(jmap[j, i]), int(imap[j, i])) + return (int(fmap[f, j, i]), int(jmap[f, j, i]), int(imap[f, j, i])) + # target coordinates - lon1, lat1 = gridlon[j1, i1], gridlat[j1, i1] - lon2, lat2 = gridlon[j2, i2], gridlat[j2, i2] - - # init loop index to starting position - i = i1 - j = j1 + lon1, lat1 = coord(gridlon, f1, j1, i1), coord(gridlat, f1, j1, i1) + lon2, lat2 = coord(gridlon, f2, j2, i2), coord(gridlat, f2, j2, i2) + + # init loop position to starting point + f, j, i = f1, j1, i1 i_c_seg = [i] # add first point to list of points j_c_seg = [j] # add first point to list of points + f_c_seg = [f] # add first point to list of points # iterate through the grid path steps until we reach end of section ct = 0 # grid path step counter + # safety bound: enough steps to cross the whole grid (all faces) once + nstep_max = (nx + ny + 1) * nfaces # Grid-agnostic algorithm: - # First, find all four neighbors (subject to grid topology) + # First, find all four neighbors (using grid topology via `neighbor_maps`) # Second, throw away any that are further from the destination than the current point # Third, go to the valid neighbor that has the smallest angle from the arc path between the # start and end points (the shortest geodesic path) - j_prev, i_prev = j,i - while (i%nx != i2) or (j != j2): - + f_prev, j_prev, i_prev = f, j, i + while (f, j, i) != (f2, j2, i2): + # safety precaution: exit after taking enough steps to have crossed the entire model grid - if ct > (nx+ny+1): + if ct > nstep_max: raise RuntimeError(f"Should have reached the endpoint by now.") d_current = distance_on_unit_sphere( - gridlon[j,i], - gridlat[j,i], + coord(gridlon, f, j, i), + coord(gridlat, f, j, i), lon2, lat2 ) - + if d_current < 1.e-12: break - - if boundary["X"] == "periodic": - right = (j, (i+1)%nx) - left = (j, (i-1)%nx) - else: - right = (j, np.clip(i+1, 0, nx-1)) - left = (j, np.clip(i-1, 0, nx-1)) - down = (np.clip(j-1, 0, ny-1), i) - - if topology=="MOM-tripolar": - if j!=ny-1: - up = (j+1, i%nx) - else: - up = (j-1, (nx-1) - (i%nx)) - - elif topology=="cartesian" or topology=="latlon": - up = (np.clip(j+1, 0, ny-1), i) - else: - raise ValueError("Only 'cartesian', 'latlon', and 'MOM-tripolar' grid topologies are currently supported.") - - neighbors = [right, left, down, up] - j_next, i_next = None, None + neighbors = [ + neighbor("right", f, j, i), + neighbor("left", f, j, i), + neighbor("down", f, j, i), + neighbor("up", f, j, i), + ] + + next_pt = None smallest_angle = np.inf d_list = [] - for (_j, _i) in neighbors: + for (_f, _j, _i) in neighbors: d = distance_on_unit_sphere( - gridlon[_j,_i], - gridlat[_j,_i], + coord(gridlon, _f, _j, _i), + coord(gridlat, _f, _j, _i), lon2, lat2 ) d_list.append(d/d_current) if d < d_current: if d==0.: # We're done! - j_next, i_next = _j, _i + next_pt = (_f, _j, _i) smallest_angle = 0. break # Instead of simply moving to the point that gets us closest to the target, @@ -559,52 +635,57 @@ def infer_grid_path(i1, j1, i2, j2, gridlon, gridlat, boundary={"X":"periodic", lat2, lon1, lat1, - gridlon[_j,_i], - gridlat[_j,_i], + coord(gridlon, _f, _j, _i), + coord(gridlat, _f, _j, _i), ) angle2 = spherical_angle( lon1, lat1, lon2, lat2, - gridlon[_j,_i], - gridlat[_j,_i], + coord(gridlon, _f, _j, _i), + coord(gridlat, _f, _j, _i), ) angle = (angle1+angle2)/2. if angle < smallest_angle: - j_next, i_next = _j, _i + next_pt = (_f, _j, _i) smallest_angle = angle - + # There can be some strange edge cases in which none of the neighboring points # actually get us closer to the target (e.g. when closing folds in the grid). # In these cases, simply pick the adjacent point that gets us closest, as long as # it was not our previous point (to avoid endless loops). This algorithm should be # guaranteed to always get us to the target point. - if (smallest_angle == np.inf) or (j_next, i_next) == (j_prev, i_prev): - if (j_prev, i_prev) in neighbors: - idx = neighbors.index((j_prev, i_prev)) + if (smallest_angle == np.inf) or (next_pt == (f_prev, j_prev, i_prev)): + if (f_prev, j_prev, i_prev) in neighbors: + idx = neighbors.index((f_prev, j_prev, i_prev)) del neighbors[idx] del d_list[idx] - - (j_next, i_next) = neighbors[np.argmin(d_list)] - j_prev, i_prev = j,i - - j = j_next - i = i_next + next_pt = neighbors[int(np.argmin(d_list))] + + f_prev, j_prev, i_prev = f, j, i + + f, j, i = next_pt i_c_seg.append(i) j_c_seg.append(j) - + f_c_seg.append(f) + ct+=1 - # create lat/lon vectors from i,j pairs + # create lat/lon vectors from (f,j,i) triples lons_c_seg = [] lats_c_seg = [] - for jj, ji in zip(j_c_seg, i_c_seg): - lons_c_seg.append(gridlon[jj, ji]) - lats_c_seg.append(gridlat[jj, ji]) - return np.array(i_c_seg), np.array(j_c_seg), np.array(lons_c_seg), np.array(lats_c_seg) + for ff, jj, ji in zip(f_c_seg, j_c_seg, i_c_seg): + lons_c_seg.append(coord(gridlon, ff, jj, ji)) + lats_c_seg.append(coord(gridlat, ff, jj, ji)) + + i_c, j_c = np.array(i_c_seg), np.array(j_c_seg) + lons_c, lats_c = np.array(lons_c_seg), np.array(lats_c_seg) + if multitile: + return i_c, j_c, np.array(f_c_seg), lons_c, lats_c + return i_c, j_c, lons_c, lats_c def find_closest_grid_point(lon, lat, gridlon, gridlat): @@ -622,8 +703,10 @@ def find_closest_grid_point(lon, lat, gridlon, gridlat): RETURNS: -------- - iclose, jclose: integer - grid indices for geographical point of interest + For 2d (single-tile) grids: + iclose, jclose: integer grid indices for the geographical point of interest + For 3d (multi-tile) grids, additionally the face index: + iclose, jclose, fclose """ if isinstance(gridlon, xr.core.dataarray.DataArray): @@ -631,7 +714,11 @@ def find_closest_grid_point(lon, lat, gridlon, gridlat): if isinstance(gridlat, xr.core.dataarray.DataArray): gridlat = gridlat.values dist = distance_on_unit_sphere(lon, lat, gridlon, gridlat) - jclose, iclose = np.unravel_index(np.nanargmin(dist), gridlon.shape) + idx = np.unravel_index(np.nanargmin(dist), gridlon.shape) + if gridlon.ndim == 3: + fclose, jclose, iclose = idx + return iclose, jclose, fclose + jclose, iclose = idx return iclose, jclose def distance_on_unit_sphere(lon1, lat1, lon2, lat2, R=6.371e6, method="vincenty"): diff --git a/sectionate/tests/test_section_multitile.py b/sectionate/tests/test_section_multitile.py new file mode 100644 index 0000000..fcf6a92 --- /dev/null +++ b/sectionate/tests/test_section_multitile.py @@ -0,0 +1,460 @@ +"""Tests for sections on multi-tile grids defined by xgcm `face_connections` +(e.g. the lat-lon-cap and cubed-sphere grids).""" + +import numpy as np +import xarray as xr +import xgcm +import pytest + +from sectionate.gridutils import build_neighbor_maps, get_geo_corners, NEIGHBOR_DIRECTIONS +from sectionate.section import grid_section +from sectionate.transports import convergent_transport + + +def _make_grid(lon, lat, face_connections): + """Build a symmetric ('outer' corner) multi-tile grid from (face, yg, xg) coords.""" + nf, ng, _ = lon.shape + ds = xr.Dataset({}, coords={ + "xg": (("xg",), np.arange(ng)), + "yg": (("yg",), np.arange(ng)), + "face": (("face",), np.arange(nf)), + "geolon_c": (("face", "yg", "xg"), lon), + "geolat_c": (("face", "yg", "xg"), lat), + }) + return xgcm.Grid( + ds, + coords={"X": {"outer": "xg"}, "Y": {"outer": "yg"}}, + boundary="fill", fill_value=np.nan, + face_connections=face_connections, + autoparse_metadata=False, + ) + + +def two_face_x_to_x(Nc=6): + """Two faces side-by-side in longitude: face0 [0,90], face1 [90,180].""" + ng = Nc + 1 + lat1d = np.linspace(-45, 45, ng) + lon = np.zeros((2, ng, ng)); lat = np.zeros((2, ng, ng)) + lon[0] = np.broadcast_to(np.linspace(0, 90, ng), (ng, ng)) + lon[1] = np.broadcast_to(np.linspace(90, 180, ng), (ng, ng)) + lat[0] = np.broadcast_to(lat1d[:, None], (ng, ng)) + lat[1] = lat[0] + fc = {"face": {0: {"X": (None, (1, "X", False))}, + 1: {"X": ((0, "X", False), None)}}} + return _make_grid(lon, lat, fc) + + +def two_face_x_to_y(Nc=4): + """face0 right-X connects to face1 Y (a 90-degree rotation).""" + ng = Nc + 1 + lon = np.zeros((2, ng, ng)); lat = np.zeros((2, ng, ng)) + # Encode (face, i) so we can read topology back out; geometry is not used here. + for f in range(2): + for j in range(ng): + for i in range(ng): + lon[f, j, i] = f * 1000 + i + lat[f, j, i] = j + fc = {"face": {0: {"X": (None, (1, "Y", False))}, + 1: {"Y": ((0, "X", False), None)}}} + return _make_grid(lon, lat, fc) + + +def cubed_sphere(Nc=4): + ng = Nc + 1 + lon = np.zeros((6, ng, ng)); lat = np.zeros((6, ng, ng)) + fc = {"face": { + 0: {"X": ((3, "X", False), (1, "X", False)), "Y": ((4, "Y", False), (5, "Y", False))}, + 1: {"X": ((0, "X", False), (2, "X", False)), "Y": ((4, "X", False), (5, "X", True))}, + 2: {"X": ((1, "X", False), (3, "X", False)), "Y": ((4, "Y", True), (5, "Y", True))}, + 3: {"X": ((2, "X", False), (0, "X", False)), "Y": ((4, "X", True), (5, "X", False))}, + 4: {"X": ((3, "Y", True), (1, "Y", False)), "Y": ((2, "Y", True), (0, "Y", False))}, + 5: {"X": ((3, "Y", False), (1, "Y", True)), "Y": ((0, "Y", False), (2, "Y", True))}, + }} + return _make_grid(lon, lat, fc) + + +def is_neighbor(prev, cur, maps): + f, j, i = prev + for d in NEIGHBOR_DIRECTIONS: + fm, jm, im = maps[d] + if (int(fm[f, j, i]), int(jm[f, j, i]), int(im[f, j, i])) == cur: + return True + return False + + +def assert_path_invariant(i_c, j_c, f_c, maps): + """Every consecutive pair in the path must be a genuine grid neighbor.""" + for k in range(len(i_c) - 1): + prev = (int(f_c[k]), int(j_c[k]), int(i_c[k])) + cur = (int(f_c[k + 1]), int(j_c[k + 1]), int(i_c[k + 1])) + if prev == cur: + continue + assert is_neighbor(prev, cur, maps), f"{prev} -> {cur} is not a grid neighbor" + + +# --------------------------------------------------------------------------- +# Neighbor-map unit tests (topology only) +# --------------------------------------------------------------------------- + +def test_neighbor_maps_x_to_x_seam(): + grid = two_face_x_to_x() + maps = build_neighbor_maps(grid, get_geo_corners(grid)) + nx = maps["right"][0].shape[-1] + # Right edge of face 0 -> face 1, landing on its left edge (i=0), same row. + fmap, jmap, imap = maps["right"] + assert np.all(fmap[0][:, -1] == 1) + assert np.all(imap[0][:, -1] == 0) + assert np.all(jmap[0][:, -1] == np.arange(grid._ds.sizes["yg"])) + # Interior steps right by one on the same face. + assert np.all(fmap[0][:, :-1] == 0) + assert np.all(imap[0][:, :-1] == np.arange(1, nx)) + + +def test_neighbor_maps_x_to_y_rotation_and_reversal(): + grid = two_face_x_to_y() + maps = build_neighbor_maps(grid, get_geo_corners(grid)) + fmap, jmap, imap = maps["right"] + n = grid._ds.sizes["yg"] + # Right edge of face0 lands on face1... + assert np.all(fmap[0][:, -1] == 1) + # ...on its bottom row (Y=0, the X->Y rotation)... + assert np.all(jmap[0][:, -1] == 0) + # ...with the tangential index reversed (j -> n-1-j). + assert np.all(imap[0][:, -1] == (n - 1) - np.arange(n)) + + +def test_cubed_sphere_reciprocal_or_raises(): + # xgcm's padding is unreliable (hash-seed dependent) for the cubed sphere's + # complex reversed/rotated connections. build_neighbor_maps must therefore + # EITHER produce a fully reciprocal (correct) topology OR refuse with a clear + # error -- it must never silently return an inconsistent topology. This holds + # regardless of which (buggy) halo xgcm happens to produce. + grid = cubed_sphere() + try: + maps = build_neighbor_maps(grid, get_geo_corners(grid)) + except NotImplementedError: + return + # If it did not raise, the maps must be self-consistent: when face 0's right + # neighbor is face 1, the connection definition (X right -> 1) must hold. + assert np.all(maps["right"][0][0][:, -1] == 1) + assert np.all(maps["left"][0][0][:, 0] == 3) + assert np.all(maps["up"][0][0][-1, :] == 5) + assert np.all(maps["down"][0][0][0, :] == 4) + + +# --------------------------------------------------------------------------- +# End-to-end section tests +# --------------------------------------------------------------------------- + +def test_cross_face_section_returns_face_and_crosses_seam(): + grid = two_face_x_to_x() + out = grid_section(grid, [15., 165.], [0., 0.]) + assert len(out) == 5 # multi-tile grids return the face index too + i_c, j_c, f_c, lons_c, lats_c = out + assert set(np.unique(f_c).tolist()) == {0, 1} # the section spans both faces + assert lons_c[0] == pytest.approx(15., abs=10.) + assert lons_c[-1] == pytest.approx(165., abs=10.) + + +def test_cross_face_section_path_invariant(): + grid = two_face_x_to_x() + maps = build_neighbor_maps(grid, get_geo_corners(grid)) + i_c, j_c, f_c, lons_c, lats_c = grid_section(grid, [15., 165.], [0., 0.]) + assert_path_invariant(i_c, j_c, f_c, maps) + + +# --------------------------------------------------------------------------- +# Fold guard +# --------------------------------------------------------------------------- + +def _two_face_transport_grid(Nc=3): + """2-face (X-X) grid with center+corner coords and U/V transports for transport tests.""" + ng = Nc + 1 + yq = np.linspace(-45, 45, ng); yh = 0.5 * (yq[:-1] + yq[1:]) + lonq = [np.linspace(0, 90, ng), np.linspace(90, 180, ng)] + lonh = [0.5 * (l[:-1] + l[1:]) for l in lonq] + + LONc = np.stack([np.broadcast_to(lonq[f], (ng, ng)) for f in range(2)]) + LATc = np.stack([np.broadcast_to(yq[:, None], (ng, ng)) for f in range(2)]) + LON = np.stack([np.broadcast_to(lonh[f], (Nc, Nc)) for f in range(2)]) + LAT = np.stack([np.broadcast_to(yh[:, None], (Nc, Nc)) for f in range(2)]) + + rng = np.arange(2 * Nc * ng, dtype=float) + u = rng[: 2 * Nc * ng].reshape(2, Nc, ng) + v = (rng[: 2 * ng * Nc].reshape(2, ng, Nc)) * 0.5 + ds = xr.Dataset( + {"u": (("face", "yh", "xq"), u), "v": (("face", "yq", "xh"), v)}, + coords={ + "xq": (("xq",), np.arange(ng)), "yq": (("yq",), np.arange(ng)), + "xh": (("xh",), np.arange(Nc)), "yh": (("yh",), np.arange(Nc)), + "face": (("face",), [0, 1]), + "geolon_c": (("face", "yq", "xq"), LONc), "geolat_c": (("face", "yq", "xq"), LATc), + "geolon": (("face", "yh", "xh"), LON), "geolat": (("face", "yh", "xh"), LAT), + }, + ) + fc = {"face": {0: {"X": (None, (1, "X", False))}, + 1: {"X": ((0, "X", False), None)}}} + grid = xgcm.Grid(ds, coords={"X": {"outer": "xq", "center": "xh"}, + "Y": {"outer": "yq", "center": "yh"}}, + boundary="fill", fill_value=np.nan, + face_connections=fc, autoparse_metadata=False) + return grid + + +def _single_face_slab(grid, face): + """Standalone single-tile grid identical to one face of a multi-tile grid.""" + ds = grid._ds.isel({grid._facedim: face}).drop_vars(grid._facedim) + return xgcm.Grid(ds, coords={"X": {"outer": "xq", "center": "xh"}, + "Y": {"outer": "yq", "center": "yh"}}, + boundary={"X": "extend", "Y": "extend"}, autoparse_metadata=False) + + +def test_within_face_transport_matches_single_tile(): + grid = _two_face_transport_grid() + # A closed box wholly within face 0 (lon in [0, 90]). + lonseg = np.array([15., 75., 75., 15., 15.]) + latseg = np.array([-15., -15., 15., 15., -15.]) + i, j, f, lons, lats = grid_section(grid, lonseg, latseg) + assert np.all(f == 0) # stays on face 0 + + conv_mt = convergent_transport( + grid, i, j, f, utr="u", vtr="v", layer=None, geometry="cartesian" + )["conv_mass_transport"].sum().values + + # Same section + transports on the standalone face-0 grid must agree exactly. + slab = _single_face_slab(grid, 0) + i0, j0, lons0, lats0 = grid_section(slab, lonseg, latseg) + assert np.array_equal(i, i0) and np.array_equal(j, j0) + conv_st = convergent_transport( + slab, i0, j0, utr="u", vtr="v", layer=None, geometry="cartesian" + )["conv_mass_transport"].sum().values + + assert np.isclose(conv_mt, conv_st, rtol=1e-14) + + +def test_within_face_tracer_matches_single_tile(): + from sectionate.tracers import extract_tracer + grid = _two_face_transport_grid() + grid._ds["theta"] = grid._ds["u"].rename({"xq": "xh"}) + 7.0 # a (face, yh, xh) tracer + lonseg = np.array([15., 75., 75., 15., 15.]) + latseg = np.array([-15., -15., 15., 15., -15.]) + i, j, f, lons, lats = grid_section(grid, lonseg, latseg) + + tr_mt = extract_tracer("theta", grid, i, j, f_c=f).values + + slab = _single_face_slab(grid, 0) + slab._ds["theta"] = grid._ds["theta"].isel({grid._facedim: 0}).drop_vars(grid._facedim) + i0, j0, lons0, lats0 = grid_section(slab, lonseg, latseg) + tr_st = extract_tracer("theta", slab, i0, j0).values + + assert np.allclose(tr_mt, tr_st, equal_nan=True) + + +def _matched_single_and_split(Nh=4, Ny=4, seed=0): + """A single-tile wide grid and the identical grid cut into two X faces (x<->x glued). + + A seam-crossing section on the split grid must give the same transport as the same + section on the uncut grid -- the strongest oracle for seam attribution. + """ + rng = np.random.default_rng(seed) + nxh = 2 * Nh + xq = np.linspace(0., 80., nxh + 1); xh = 0.5 * (xq[:-1] + xq[1:]) + yq = np.linspace(-40., 40., Ny + 1); yh = 0.5 * (yq[:-1] + yq[1:]) + u = rng.standard_normal((Ny, nxh + 1)) # umo at (yh, xq) + v = rng.standard_normal((Ny + 1, nxh)) # vmo at (yq, xh) + + def grid2d(sl_q, sl_h, ds_extra): + ds = xr.Dataset(ds_extra, coords={ + "xq": (("xq",), np.arange(sl_q)), "yq": (("yq",), np.arange(Ny + 1)), + "xh": (("xh",), np.arange(sl_h)), "yh": (("yh",), np.arange(Ny)), + }) + return ds + + # single-tile + ds_s = xr.Dataset( + {"u": (("yh", "xq"), u), "v": (("yq", "xh"), v)}, + coords={ + "xq": (("xq",), np.arange(nxh + 1)), "yq": (("yq",), np.arange(Ny + 1)), + "xh": (("xh",), np.arange(nxh)), "yh": (("yh",), np.arange(Ny)), + "geolon_c": (("yq", "xq"), np.broadcast_to(xq, (Ny + 1, nxh + 1))), + "geolat_c": (("yq", "xq"), np.broadcast_to(yq[:, None], (Ny + 1, nxh + 1))), + "geolon": (("yh", "xh"), np.broadcast_to(xh, (Ny, nxh))), + "geolat": (("yh", "xh"), np.broadcast_to(yh[:, None], (Ny, nxh))), + }, + ) + single = xgcm.Grid(ds_s, coords={"X": {"outer": "xq", "center": "xh"}, + "Y": {"outer": "yq", "center": "yh"}}, + boundary={"X": "extend", "Y": "extend"}, autoparse_metadata=False) + + # two faces: cols [0:Nh] and [Nh:2Nh]; faces share the boundary corner column at Nh. + def face_slices(arr_q, arr_h): + return (np.stack([arr_q[..., 0:Nh + 1], arr_q[..., Nh:2 * Nh + 1]]), + np.stack([arr_h[..., 0:Nh], arr_h[..., Nh:2 * Nh]])) + LONc = np.stack([np.broadcast_to(xq[0:Nh + 1], (Ny + 1, Nh + 1)), + np.broadcast_to(xq[Nh:2 * Nh + 1], (Ny + 1, Nh + 1))]) + LATc = np.broadcast_to(yq[:, None], (Ny + 1, Nh + 1))[None].repeat(2, 0) + LON = np.stack([np.broadcast_to(xh[0:Nh], (Ny, Nh)), + np.broadcast_to(xh[Nh:2 * Nh], (Ny, Nh))]) + LAT = np.broadcast_to(yh[:, None], (Ny, Nh))[None].repeat(2, 0) + uf = np.stack([u[:, 0:Nh + 1], u[:, Nh:2 * Nh + 1]]) + vf = np.stack([v[:, 0:Nh], v[:, Nh:2 * Nh]]) + ds_f = xr.Dataset( + {"u": (("face", "yh", "xq"), uf), "v": (("face", "yq", "xh"), vf)}, + coords={ + "xq": (("xq",), np.arange(Nh + 1)), "yq": (("yq",), np.arange(Ny + 1)), + "xh": (("xh",), np.arange(Nh)), "yh": (("yh",), np.arange(Ny)), + "face": (("face",), [0, 1]), + "geolon_c": (("face", "yq", "xq"), LONc), "geolat_c": (("face", "yq", "xq"), LATc), + "geolon": (("face", "yh", "xh"), LON), "geolat": (("face", "yh", "xh"), LAT), + }, + ) + fc = {"face": {0: {"X": (None, (1, "X", False))}, 1: {"X": ((0, "X", False), None)}}} + split = xgcm.Grid(ds_f, coords={"X": {"outer": "xq", "center": "xh"}, + "Y": {"outer": "yq", "center": "yh"}}, + boundary="fill", fill_value=np.nan, + face_connections=fc, autoparse_metadata=False) + return single, split + + +def test_seam_crossing_transport_matches_uncut_grid(): + single, split = _matched_single_and_split() + # A closed box straddling the seam (lon 40). + lonseg = np.array([10., 70., 70., 10., 10.]) + latseg = np.array([-20., -20., 20., 20., -20.]) + + i, j, lons, lats = grid_section(single, lonseg, latseg) + conv_single = convergent_transport( + single, i, j, utr="u", vtr="v", layer=None, geometry="cartesian" + )["conv_mass_transport"].sum().values + + i2, j2, f2, lons2, lats2 = grid_section(split, lonseg, latseg) + assert set(np.unique(f2).tolist()) == {0, 1} # the section really crosses the seam + conv_split = convergent_transport( + split, i2, j2, f2, utr="u", vtr="v", layer=None, geometry="cartesian" + )["conv_mass_transport"].sum().values + + assert np.isclose(conv_single, conv_split, rtol=1e-12) + + +def _matched_single_and_split_nonsym(Nh=4, Ny=4, seed=0): + """Non-symmetric ('right' corner) analogue of `_matched_single_and_split`: a single-tile + grid and the identical grid cut into two X faces. Non-symmetric tilings do NOT share a + boundary vorticity column, so seam crossings are real (non-degenerate) velocity edges.""" + rng = np.random.default_rng(seed) + nxh = 2 * Nh + xh = np.linspace(5., 75., nxh); xq = xh + 5. # centers; 'right' edges + yh = np.linspace(-30., 30., Ny); yq = yh + 10. + u = rng.standard_normal((Ny, nxh)) # umo at (yh, xq) + v = rng.standard_normal((Ny, nxh)) # vmo at (yq, xh) + coords = {"X": {"right": "xq", "center": "xh"}, "Y": {"right": "yq", "center": "yh"}} + + ds_s = xr.Dataset( + {"u": (("yh", "xq"), u), "v": (("yq", "xh"), v)}, + coords={ + "xq": (("xq",), np.arange(nxh)), "yq": (("yq",), np.arange(Ny)), + "xh": (("xh",), np.arange(nxh)), "yh": (("yh",), np.arange(Ny)), + "geolon_c": (("yq", "xq"), np.broadcast_to(xq, (Ny, nxh))), + "geolat_c": (("yq", "xq"), np.broadcast_to(yq[:, None], (Ny, nxh))), + "geolon": (("yh", "xh"), np.broadcast_to(xh, (Ny, nxh))), + "geolat": (("yh", "xh"), np.broadcast_to(yh[:, None], (Ny, nxh))), + }) + single = xgcm.Grid(ds_s, coords=coords, boundary={"X": "extend", "Y": "extend"}, + autoparse_metadata=False) + + LONc = np.stack([np.broadcast_to(xq[0:Nh], (Ny, Nh)), np.broadcast_to(xq[Nh:2 * Nh], (Ny, Nh))]) + LATc = np.broadcast_to(yq[:, None], (Ny, Nh))[None].repeat(2, 0) + LON = np.stack([np.broadcast_to(xh[0:Nh], (Ny, Nh)), np.broadcast_to(xh[Nh:2 * Nh], (Ny, Nh))]) + LAT = np.broadcast_to(yh[:, None], (Ny, Nh))[None].repeat(2, 0) + ds_f = xr.Dataset( + {"u": (("face", "yh", "xq"), np.stack([u[:, 0:Nh], u[:, Nh:2 * Nh]])), + "v": (("face", "yq", "xh"), np.stack([v[:, 0:Nh], v[:, Nh:2 * Nh]]))}, + coords={ + "xq": (("xq",), np.arange(Nh)), "yq": (("yq",), np.arange(Ny)), + "xh": (("xh",), np.arange(Nh)), "yh": (("yh",), np.arange(Ny)), + "face": (("face",), [0, 1]), + "geolon_c": (("face", "yq", "xq"), LONc), "geolat_c": (("face", "yq", "xq"), LATc), + "geolon": (("face", "yh", "xh"), LON), "geolat": (("face", "yh", "xh"), LAT), + }) + fc = {"face": {0: {"X": (None, (1, "X", False))}, 1: {"X": ((0, "X", False), None)}}} + split = xgcm.Grid(ds_f, coords=coords, boundary="fill", fill_value=np.nan, + face_connections=fc, autoparse_metadata=False) + return single, split + + +def test_nonsymmetric_seam_transport_matches_uncut_grid(): + single, split = _matched_single_and_split_nonsym() + lonseg = np.array([15., 65., 65., 15., 15.]) + latseg = np.array([-15., -15., 15., 15., -15.]) + i, j, lons, lats = grid_section(single, lonseg, latseg) + conv_single = convergent_transport( + single, i, j, utr="u", vtr="v", layer=None, geometry="cartesian" + )["conv_mass_transport"].sum().values + i2, j2, f2, l2, la2 = grid_section(split, lonseg, latseg) + assert set(np.unique(f2).tolist()) == {0, 1} + conv_split = convergent_transport( + split, i2, j2, f2, utr="u", vtr="v", layer=None, geometry="cartesian" + )["conv_mass_transport"].sum().values + assert np.isclose(conv_single, conv_split, rtol=1e-12) + + +def _psi(lon, lat): + """A smooth streamfunction. For a non-divergent flow (umo,vmo = its discrete curl), the + transport across any section from P1 to P2 equals psi(P2) - psi(P1) -- a known, non-zero, + rotation-independent answer, so it pins the per-edge signs without any external data.""" + return 1.3 * lon - 0.7 * lat + 0.02 * lon * lat + + +def _uv_from_psi(P): + """umo = -dpsi/dy (at yh,xq), vmo = +dpsi/dx (at yq,xh), in the face's own frame.""" + return -(P[1:, :] - P[:-1, :]), (P[:, 1:] - P[:, :-1]) + + +def rotated_two_face_streamfunction(): + """Two faces meeting at a 90-degree ROTATION (face 1's local axes: +y->east, +x->south), + carrying a non-divergent streamfunction flow. face 0 is a normal lon[0,40] x lat[0,40] + square; face 1 covers lon[40,80] x lat[0,40] rotated. A proper (orientation-preserving) + rotation that attaches here needs no `reversed` flag on the connection.""" + a = np.arange(5) + f0_lonc = np.broadcast_to(a * 10., (5, 5)); f0_latc = np.broadcast_to((a * 10.)[:, None], (5, 5)) + f1_lonc = np.broadcast_to((40. + a * 10.)[:, None], (5, 5)); f1_latc = np.broadcast_to(40. - a * 10., (5, 5)) + LONc = np.stack([f0_lonc, f1_lonc]); LATc = np.stack([f0_latc, f1_latc]) + c = np.arange(4) + f0_lon = np.broadcast_to(5. + c * 10., (4, 4)); f0_lat = np.broadcast_to((5. + c * 10.)[:, None], (4, 4)) + f1_lon = np.broadcast_to((45. + c * 10.)[:, None], (4, 4)); f1_lat = np.broadcast_to(35. - c * 10., (4, 4)) + LON = np.stack([f0_lon, f1_lon]); LAT = np.stack([f0_lat, f1_lat]) + u0, v0 = _uv_from_psi(_psi(f0_lonc, f0_latc)) + u1, v1 = _uv_from_psi(_psi(f1_lonc, f1_latc)) + ds = xr.Dataset( + {"u": (("face", "yh", "xq"), np.stack([u0, u1])), "v": (("face", "yq", "xh"), np.stack([v0, v1]))}, + coords={ + "xq": (("xq",), np.arange(5)), "yq": (("yq",), np.arange(5)), + "xh": (("xh",), np.arange(4)), "yh": (("yh",), np.arange(4)), "face": (("face",), [0, 1]), + "geolon_c": (("face", "yq", "xq"), LONc), "geolat_c": (("face", "yq", "xq"), LATc), + "geolon": (("face", "yh", "xh"), LON), "geolat": (("face", "yh", "xh"), LAT), + }) + fc = {"face": {0: {"X": (None, (1, "Y", False))}, 1: {"Y": ((0, "X", False), None)}}} + return xgcm.Grid(ds, coords={"X": {"outer": "xq", "center": "xh"}, "Y": {"outer": "yq", "center": "yh"}}, + boundary="fill", fill_value=np.nan, face_connections=fc, autoparse_metadata=False) + + +def test_rotated_seam_transport_streamfunction(): + # Transport across a section crossing a ROTATED seam must equal the streamfunction + # difference between its endpoints -- the geographic per-edge sign makes this hold even + # though face 1's grid axes are rotated 90 degrees relative to face 0. + grid = rotated_two_face_streamfunction() + i, j, f, lons, lats = grid_section(grid, [10., 70.], [20., 20.]) + assert set(np.unique(f).tolist()) == {0, 1} # the section really crosses the rotated seam + conv = convergent_transport( + grid, i, j, f, utr="u", vtr="v", layer=None, geometry="cartesian" + )["conv_mass_transport"].sum().values + dpsi = _psi(lons[-1], lats[-1]) - _psi(lons[0], lats[0]) + assert np.isclose(abs(conv), abs(dpsi), rtol=1e-9) + + +def test_north_fold_self_connection_raises(): + ng = 5 + lon = np.zeros((1, ng, ng)); lat = np.zeros((1, ng, ng)) + fc = {"face": {0: {"Y": ((0, "Y", True), (0, "Y", True))}}} + grid = _make_grid(lon, lat, fc) + with pytest.raises(NotImplementedError): + grid_section(grid, [0., 1.], [0., 1.]) diff --git a/sectionate/tracers.py b/sectionate/tracers.py index 040f8f8..07243ec 100644 --- a/sectionate/tracers.py +++ b/sectionate/tracers.py @@ -6,7 +6,8 @@ from .gridutils import ( check_symmetric, - coord_dict + coord_dict, + get_facedim ) def extract_tracer( @@ -14,6 +15,7 @@ def extract_tracer( grid, i_c, j_c, + f_c=None, sect_coord="sect" ): """ @@ -42,26 +44,30 @@ def extract_tracer( da=grid._ds[name] coords = coord_dict(grid) symmetric = check_symmetric(grid) - + # get indices of UV points from broken line - uvindices = uvindices_from_qindices(grid, i_c, j_c) - + uvindices = uvindices_from_qindices(grid, i_c, j_c, f_c=f_c) + section = xr.Dataset() section["i"] = xr.DataArray(uvindices["i"], dims=sect_coord) section["j"] = xr.DataArray(uvindices["j"], dims=sect_coord) section["Umask"] = xr.DataArray(uvindices["var"]=="U", dims=sect_coord) section["Vmask"] = xr.DataArray(uvindices["var"]=="V", dims=sect_coord) + # On a multi-tile grid the velocity face varies per section point; select it pointwise. + facedim = get_facedim(grid) if f_c is not None else None + fsel = {facedim: xr.DataArray(uvindices["face"], dims=sect_coord)} if facedim is not None else {} + increment = 1 if symmetric else 0 usel_left = {coords["X"]["center"]: np.mod(section["i"]-increment , da[coords["X"]["center"]].size), - coords["Y"]["center"]: np.mod(section["j"] , da[coords["Y"]["center"]].size)} + coords["Y"]["center"]: np.mod(section["j"] , da[coords["Y"]["center"]].size), **fsel} usel_right = {coords["X"]["center"]: np.mod(section["i"]-increment+1, da[coords["X"]["center"]].size), - coords["Y"]["center"]: np.mod(section["j"] , da[coords["Y"]["center"]].size)} + coords["Y"]["center"]: np.mod(section["j"] , da[coords["Y"]["center"]].size), **fsel} vsel_left = {coords["X"]["center"]: np.mod(section["i"] , da[coords["X"]["center"]].size), - coords["Y"]["center"]: np.mod(section["j"]-increment , da[coords["Y"]["center"]].size)} + coords["Y"]["center"]: np.mod(section["j"]-increment , da[coords["Y"]["center"]].size), **fsel} vsel_right = {coords["X"]["center"]: np.mod(section["i"] , da[coords["X"]["center"]].size), - coords["Y"]["center"]: np.mod(section["j"]-increment+1, da[coords["Y"]["center"]].size)} + coords["Y"]["center"]: np.mod(section["j"]-increment+1, da[coords["Y"]["center"]].size), **fsel} tracer = sum([ xr.where(~np.isnan(da.isel(usel_right)), 0.5*da.isel(usel_left), da.isel(usel_left) ).fillna(0.) * section["Umask"], diff --git a/sectionate/transports.py b/sectionate/transports.py index 95523b1..8b0366a 100644 --- a/sectionate/transports.py +++ b/sectionate/transports.py @@ -3,10 +3,119 @@ import xarray as xr import dask -from .gridutils import check_symmetric, coord_dict, get_geo_corners +from .gridutils import ( + check_symmetric, coord_dict, get_geo_corners, get_facedim, build_neighbor_maps, + NEIGHBOR_DIRECTIONS, +) from .section import distance_on_unit_sphere -def uvindices_from_qindices(grid, i_c, j_c): + +def _edge_direction(A, neighbor_maps): + """Return a function mapping a neighbor point to the direction (right/left/up/down) + that reaches it from A=(f,j,i), or None if it is not a neighbor of A.""" + f, j, i = A + out = {} + for d in NEIGHBOR_DIRECTIONS: + fm, jm, im = neighbor_maps[d] + out[(int(fm[f, j, i]), int(jm[f, j, i]), int(im[f, j, i]))] = d + return out + + +# velocity (var, index-offset) for the section edge leaving corner (i,j) in each +# local direction, on a *symmetric* C-grid. V (vmo) lives at (X-center, Y-corner); +# U (umo) at (X-corner, Y-center). +_EDGE_VEL = { + "right": ("V", 0, 0), # vmo at (center i, corner j) + "left": ("V", -1, 0), # vmo at (center i-1, corner j) + "up": ("U", 0, 0), # umo at (corner i, center j) + "down": ("U", 0, -1), # umo at (corner i, center j-1) +} + + +def _in_velocity_range(var, vi, vj, ranges): + """Whether a velocity index is a real point on its face (vs. an off-face artifact).""" + if var == "V": # vmo at (X-center, Y-corner) + return (0 <= vi < ranges["Xc"]) and (0 <= vj < ranges["Yq"]) + return (0 <= vi < ranges["Xq"]) and (0 <= vj < ranges["Yc"]) # umo at (X-corner, Y-center) + + +def _anchor_velocity(d, f, j, i, symmetric): + """Staggered velocity index for the section edge leaving corner (f,j,i) in direction d, + read in face f's own frame. Returns (var, vi, vj).""" + var, di, dj = _EDGE_VEL[d] + vi, vj = i + di, j + dj + if not symmetric: # non-symmetric grids shift the staggered velocity index by one + vi, vj = (vi + 1, vj) if var == "V" else (vi, vj + 1) + return var, vi, vj + + +def _local_vec(lon0, lat0, lon1, lat1): + """Displacement (point0 -> point1) in a local flat (east, north) frame, in degrees, + with longitudes scaled by cos(lat) and wrapped across the dateline.""" + dlon = ((lon1 - lon0 + 180.0) % 360.0) - 180.0 + return dlon * np.cos(np.deg2rad(0.5 * (lat0 + lat1))), (lat1 - lat0) + + +def _left_sign(var, fv, jc, ic, A, B, glon, glat): + """ + +1 if the stored velocity's positive direction points to the LEFT of the section's + direction of travel (A -> B), else -1. Computed from geography, so it stays consistent + across faces however the grid is rotated underneath -- which is what makes seam-crossing + (including rotated) transports orient correctly without a single global flip. + + `var` is "U"/"V"; (fv, jc, ic) is the section corner on the velocity's own face, used to + read the velocity's positive (face +x for U, +y for V) direction from the corner positions. + """ + fA, jA, iA = A + fB, jB, iB = B + tx, ty = _local_vec(glon[fA, jA, iA], glat[fA, jA, iA], glon[fB, jB, iB], glat[fB, jB, iB]) + ny, nx = glon.shape[-2], glon.shape[-1] + if var == "U": # umo positive -> face +x + i1, i2 = (ic, ic + 1) if ic + 1 < nx else (ic - 1, ic) + vx, vy = _local_vec(glon[fv, jc, i1], glat[fv, jc, i1], glon[fv, jc, i2], glat[fv, jc, i2]) + else: # vmo positive -> face +y + j1, j2 = (jc, jc + 1) if jc + 1 < ny else (jc - 1, jc) + vx, vy = _local_vec(glon[fv, j1, ic], glat[fv, j1, ic], glon[fv, j2, ic], glat[fv, j2, ic]) + return 1 if (tx * vy - ty * vx) > 0 else -1 # cross(travel, vdir) > 0 <=> vdir is left + + +def _uv_for_edge(A, B, neighbor_maps, symmetric, ranges, glon, glat): + """ + Velocity face for the directed section edge from corner A=(fA,jA,iA) to B=(fB,jB,iB). + Returns (var, i, j, face, Lsign), where Lsign is +1 if the stored velocity's positive + direction points left of travel (a geographic sign -- see `_left_sign`); var is "0" for a + degenerate edge that carries no flux. + + The velocity index is read in a single face's frame, so no velocity is rotated across the + seam: the SOURCE face when the edge's normal velocity lives there (the usual case, and where + a rotated connection needs no rotation since the edge is a normal X/Y edge on that face); + otherwise the DESTINATION face (e.g. the trailing edge of a non-symmetric crossing); + otherwise the edge is degenerate (a crossing through a shared boundary corner of a symmetric + tiling). The sign is always geographic, so rotated seams orient correctly. + """ + fA, jA, iA = A + fB, jB, iB = B + d = _edge_direction(A, neighbor_maps)[B] + seam = fA != fB + + # 1. source-frame velocity (always valid within a face) + var_s, vi_s, vj_s = _anchor_velocity(d, fA, jA, iA, symmetric) + if not seam or _in_velocity_range(var_s, vi_s, vj_s, ranges): + Lsign = _left_sign(var_s, fA, jA, iA, A, B, glon, glat) + return var_s, int(vi_s), int(vj_s), int(fA), Lsign + + # 2. destination-frame velocity (the edge sits on B's d2-side) + d2 = _edge_direction(B, neighbor_maps)[A] # direction from B back to A + var_d, vi_d, vj_d = _anchor_velocity(d2, fB, jB, iB, symmetric) + if _in_velocity_range(var_d, vi_d, vj_d, ranges): + Lsign = _left_sign(var_d, fB, jB, iB, A, B, glon, glat) + return var_d, int(vi_d), int(vj_d), int(fB), Lsign + + # 3. degenerate crossing through a shared boundary corner -- carries no flux. + return "0", 0, 0, int(fB), 0 + + +def uvindices_from_qindices(grid, i_c, j_c, f_c=None): """ Find the `grid` indices of the N-1 velocity points defined by the consecutive indices of N vorticity points. Follows MOM6 conventions (https://mom6.readthedocs.io/en/main/api/generated/pages/Horizontal_Indexing.html), @@ -28,31 +137,62 @@ def uvindices_from_qindices(grid, i_c, j_c): - "var" : "U" if corresponding to "X"-direction velocity (usually nominally zonal), "V" otherwise - "i" : "X"-dimension index of appropriate "U" or "V" velocity - "j" : "Y"-dimension index of appropriate "U" or "V" velocity - - "nward" : True if point was passed through while going in positive "j"-index direction - - "eward" : True if point was passed through while going in positive "i"-index direction + - "Yinc" : True if point was passed through while going in positive "j"-index direction + - "Xinc" : True if point was passed through while going in positive "i"-index direction """ nsec = len(i_c) uvindices = { "var":np.zeros(nsec-1, dtype=" i_c[k] - nward = j_c[k+1] > j_c[k] + Xinc = i_c[k+1] > i_c[k] + Yinc = j_c[k+1] > j_c[k] # Handle corner cases for wrapping boundaries - if (i_c[k+1] - i_c[k])>1: eward = False - elif (i_c[k+1] - i_c[k])<-1: eward = True + if (i_c[k+1] - i_c[k])>1: Xinc = False + elif (i_c[k+1] - i_c[k])<-1: Xinc = True uvindex = { - "var": "V" if zonal else "U", - "i": i_c[k+(1 if not(eward) and zonal else 0)], - "j": j_c[k+(1 if not(nward) and not(zonal) else 0)], - "nward": nward, - "eward": eward, + "var": "V" if zonal else "U", + "i": i_c[k+(1 if not(Xinc) and zonal else 0)], + "j": j_c[k+(1 if not(Yinc) and not(zonal) else 0)], + "Yinc": Yinc, + "Xinc": Xinc, } uvindex["i"] += (1 if not(symmetric) and zonal else 0) uvindex["j"] += (1 if not(symmetric) and not(zonal) else 0) @@ -75,8 +215,8 @@ def uvcoords_from_uvindices(grid, uvindices): - "var" : "U" if corresponding to "X"-direction velocity (usually nominally zonal), "V" otherwise - "i" : "X"-dimension index of appropriate "U" or "V" velocity - "j" : "Y"-dimension index of appropriate "U" or "V" velocity - - "nward" : True if point was passed through while going in positive "j"-index direction - - "eward" : True if point was passed through while going in positive "i"-index direction + - "Yinc" : True if point was passed through while going in positive "j"-index direction + - "Xinc" : True if point was passed through while going in positive "i"-index direction RETURNS: -------- @@ -109,26 +249,39 @@ def uvcoords_from_uvindices(grid, uvindices): (coords["Y"]["corner"] in ds[c].coords)) if d in c}.items()} + facedim = get_facedim(grid) + faces = uvindices.get("face") + for p in range(len(uvindices["var"])): var, i, j = uvindices["var"][p], uvindices["i"][p], uvindices["j"][p] + if var not in ("U", "V"): + # Degenerate edge (e.g. a seam crossing through a shared corner): no point. + lons[p], lats[p] = np.nan, np.nan + continue + # On multi-tile grids, also select the velocity point's face. + fsel = {facedim: int(faces[p])} if (facedim is not None and faces is not None) else {} if var == "U": if (f"geolon_u" in u_names) and (f"geolat_u" in u_names): lon = ds[u_names[f"geolon_u"]].isel({ coords["X"]["corner"]:i, - coords["Y"]["center"]:j + coords["Y"]["center"]:j, + **fsel }).values lat = ds[u_names[f"geolat_u"]].isel({ coords["X"]["corner"]:i, - coords["Y"]["center"]:j + coords["Y"]["center"]:j, + **fsel }).values elif (f"geolon_corner" in corner_names) and (f"geolat_center" in center_names): lon = ds[corner_names[f"geolon_corner"]].isel({ coords["X"]["corner"]:i, - coords["Y"]["corner"]:j + coords["Y"]["corner"]:j, + **fsel }).values lat = ds[center_names[f"geolat_center"]].isel({ coords["X"]["center"]:wrap_idx(i, grid, "X"), - coords["Y"]["center"]:wrap_idx(j, grid, "Y") + coords["Y"]["center"]:wrap_idx(j, grid, "Y"), + **fsel }).values else: raise ValueError("Cannot locate grid coordinates necessary to\ @@ -137,20 +290,24 @@ def uvcoords_from_uvindices(grid, uvindices): if (f"geolon_v" in v_names) and (f"geolat_v" in v_names): lon = ds[v_names[f"geolon_v"]].isel({ coords["X"]["center"]:wrap_idx(i, grid, "X"), - coords["Y"]["corner"]:j + coords["Y"]["corner"]:j, + **fsel }).values lat = ds[v_names[f"geolat_v"]].isel({ coords["X"]["center"]:wrap_idx(i, grid, "X"), - coords["Y"]["corner"]:j + coords["Y"]["corner"]:j, + **fsel }).values elif (f"geolon_center" in center_names) and (f"geolat_corner" in corner_names): lon = ds[center_names[f"geolon_center"]].isel({ coords["X"]["center"]:wrap_idx(i, grid, "X"), - coords["Y"]["center"]:wrap_idx(j, grid, "Y") + coords["Y"]["center"]:wrap_idx(j, grid, "Y"), + **fsel }).values lat = ds[corner_names[f"geolat_corner"]].isel({ coords["X"]["corner"]:i, - coords["Y"]["corner"]:j + coords["Y"]["corner"]:j, + **fsel }).values else: raise ValueError("Cannot locate grid coordinates necessary to\ @@ -159,7 +316,7 @@ def uvcoords_from_uvindices(grid, uvindices): lats[p] = lat return lons, lats -def uvcoords_from_qindices(grid, i_c, j_c): +def uvcoords_from_qindices(grid, i_c, j_c, f_c=None): """ Directly finds coordinates of velocity points from vorticity point indices, wrapping other functions. @@ -168,9 +325,11 @@ def uvcoords_from_qindices(grid, i_c, j_c): grid: xgcm.Grid Grid object describing ocean model grid and containing data variables i_c: int - vorticity point indices along "X" dimension + vorticity point indices along "X" dimension j_c: int vorticity point indices along "Y" dimension + f_c: int or None + Face indices of the vorticity points for multi-tile grids; None for single-tile grids. RETURNS: -------- @@ -179,13 +338,14 @@ def uvcoords_from_qindices(grid, i_c, j_c): """ return uvcoords_from_uvindices( grid, - uvindices_from_qindices(grid, i_c, j_c), + uvindices_from_qindices(grid, i_c, j_c, f_c=f_c), ) def convergent_transport( grid, i_c, j_c, + f_c=None, utr="umo", vtr="vmo", layer="z_l", @@ -245,9 +405,13 @@ def convergent_transport( if layer.replace("l", "i") != interface: raise ValueError("Inconsistent layer and interface grid variables!") - uvindices = uvindices_from_qindices(grid, i_c, j_c) - uvcoords = uvcoords_from_qindices(grid, i_c, j_c) - + # On a multi-tile grid the contributing velocity face varies along the section + # (`uvindices["face"]`); it is selected pointwise in every `.isel` below. + facedim = get_facedim(grid) if f_c is not None else None + + uvindices = uvindices_from_qindices(grid, i_c, j_c, f_c=f_c) + uvcoords = uvcoords_from_qindices(grid, i_c, j_c, f_c=f_c) + sect = xr.Dataset() sect = sect.assign_coords({ sect_coord: xr.DataArray( @@ -257,30 +421,48 @@ def convergent_transport( }) sect["i"] = xr.DataArray(uvindices["i"], dims=sect_coord) sect["j"] = xr.DataArray(uvindices["j"], dims=sect_coord) + if facedim is not None: + sect["face"] = xr.DataArray(uvindices["face"], dims=sect_coord) + fsel = {facedim: sect["face"]} if facedim is not None else {} sect["Usign"] = xr.DataArray( - np.array([1 if i else -1 for i in ~uvindices["nward"]]), + np.array([1 if i else -1 for i in ~uvindices["Yinc"]]), dims=sect_coord ) sect["Vsign"] = xr.DataArray( - np.array([1 if i else -1 for i in uvindices["eward"]]), + np.array([1 if i else -1 for i in uvindices["Xinc"]]), dims=sect_coord ) sect["var"] = xr.DataArray(uvindices["var"], dims=sect_coord) sect["Umask"] = xr.DataArray(uvindices["var"]=="U", dims=sect_coord) sect["Vmask"] = xr.DataArray(uvindices["var"]=="V", dims=sect_coord) - + + # Per-edge sign: +1 if the velocity's positive direction points left of the section's + # direction of travel. Multi-tile grids carry this geometrically (`Lsign`), so it stays + # consistent across rotated faces; single-tile grids reduce to the original face-frame + # Usign/Vsign, leaving those results unchanged. + if facedim is not None: + sect["Lsign"] = xr.DataArray(uvindices["Lsign"], dims=sect_coord) + else: + sect["Lsign"] = sect["Usign"]*sect["Umask"] + sect["Vsign"]*sect["Vmask"] + mask_types = (np.ndarray, dask.array.Array, xr.DataArray) if isinstance(positive_in, mask_types): positive_in = is_mask_inside(positive_in, grid, sect) else: - if (geometry == "cartesian") and (grid.axes["X"]._boundary == "periodic"): + if (geometry == "cartesian") and (grid.axes["X"].boundary == "periodic"): raise ValueError("Periodic cartesian domains are not yet supported!") coords = coord_dict(grid) geo_corners = get_geo_corners(grid) + # corner-grid selection is over all N corners ("pt"); the face varies per corner. + corner_fsel = ( + {facedim: xr.DataArray(np.asarray(f_c), dims=("pt",))} + if facedim is not None else {} + ) idx = { coords["X"]["corner"]:xr.DataArray(i_c, dims=("pt",)), coords["Y"]["corner"]:xr.DataArray(j_c, dims=("pt",)), + **corner_fsel } counterclockwise = is_section_counterclockwise( geo_corners["X"].isel(idx).values, @@ -289,30 +471,26 @@ def convergent_transport( ) positive_in = positive_in ^ (not(counterclockwise)) orient_fact = 1 if positive_in else -1 - + coords = coord_dict(grid) usel = { coords["X"]["corner"]: sect["i"], - coords["Y"]["center"]: wrap_idx(sect["j"], grid, "Y") + coords["Y"]["center"]: wrap_idx(sect["j"], grid, "Y"), + **fsel } vsel = { coords["X"]["center"]: wrap_idx(sect["i"], grid, "X"), - coords["Y"]["corner"]: sect["j"] + coords["Y"]["corner"]: sect["j"], + **fsel } u = grid._ds[utr] v = grid._ds[vtr] - conv_umo_masked = ( - u.isel(usel).fillna(0.) - *sect["Usign"]*sect["Umask"] - ) - conv_vmo_masked = ( - v.isel(vsel).fillna(0.) - *sect["Vsign"]*sect["Vmask"] - ) + conv_umo_masked = u.isel(usel).fillna(0.)*sect["Umask"] + conv_vmo_masked = v.isel(vsel).fillna(0.)*sect["Vmask"] conv_transport = xr.DataArray( - (conv_umo_masked + conv_vmo_masked)*orient_fact, + (conv_umo_masked + conv_vmo_masked)*sect["Lsign"]*orient_fact, ) dsout = xr.Dataset({outname: conv_transport}) @@ -333,10 +511,7 @@ def convergent_transport( }) dsout = dsout.assign_coords({ - "sign": orient_fact*( - sect["Usign"]*sect["Umask"] + - sect["Vsign"]*sect["Vmask"] - ), + "sign": orient_fact*sect["Lsign"], "dir": xr.DataArray( np.array(["U" if u else "V" for u in sect["Umask"]]), coords=(dsout[sect_coord],), @@ -449,6 +624,8 @@ def is_mask_inside(mask, grid, sect, idx=0): """ symmetric = check_symmetric(grid) coords = coord_dict(grid) + facedim = get_facedim(grid) + fsel = {facedim: int(sect["face"][idx])} if (facedim is not None and "face" in sect) else {} if sect["var"][idx]=="U": i = ( sect["i"][idx] @@ -459,17 +636,20 @@ def is_mask_inside(mask, grid, sect, idx=0): if 0<=i<=grid._ds[coords["X"]["center"]].size-1: positive_in = mask.isel({ coords["X"]["center"]: i, - coords["Y"]["center"]: j + coords["Y"]["center"]: j, + **fsel }).values elif i==-1: positive_in = not(mask.isel({ coords["X"]["center"]: i+1, - coords["Y"]["center"]: j + coords["Y"]["center"]: j, + **fsel })).values elif i==grid._ds[coords["X"]["center"]].size: positive_in = not(mask.isel({ coords["X"]["center"]: i-1, - coords["Y"]["center"]: j + coords["Y"]["center"]: j, + **fsel })).values elif sect["var"][idx]=="V": i = sect["i"][idx] @@ -481,24 +661,27 @@ def is_mask_inside(mask, grid, sect, idx=0): if 0<=j<=grid._ds[coords["Y"]["center"]].size-1: positive_in = mask.isel({ coords["X"]["center"]: i, - coords["Y"]["center"]: j + coords["Y"]["center"]: j, + **fsel }).values elif j==-1: positive_in = not(mask.isel({ coords["X"]["center"]: i, coords["Y"]["center"]: j+1, + **fsel })).values elif j==grid._ds[coords["Y"]["center"]].size: positive_in = not(mask.isel({ coords["X"]["center"]: i, - coords["Y"]["center"]: j-1 + coords["Y"]["center"]: j-1, + **fsel })).values return positive_in def wrap_idx(idx, grid, axis): coords = coord_dict(grid) - if grid.axes[axis]._boundary == "periodic": + if grid.axes[axis].boundary == "periodic": idx = np.mod(idx, grid._ds[coords[axis]["center"]].size) else: idx = np.minimum(idx, grid._ds[coords[axis]["center"]].size-1)