Source code for ants.regrid.esmf

# (C) Crown Copyright, Met Office. All rights reserved.
#
# This file is part of ANTS and is released under the BSD 3-Clause license.
# See LICENSE.txt in the root of the repository for full licensing details.
import functools
import hashlib
import operator
import os
import tempfile
import warnings

import ants
import cartopy.crs as ccrs
import iris
import numpy as np
from scipy.stats import rankdata
from shapely.geometry import LinearRing

ESMPY_IMPORT_MESSAGE = """To use ESMF, set the ESMFMKFILE environment variable.
https://earthsystemmodeling.org/esmpy_doc/release/latest/html/install.html
#importing-esmpy"""

try:
    import esmpy

    _ESMPY_IMPORT_ERROR = False
except Exception as _ESMPY_IMPORT_ERROR:
    if "ESMFMKFILE" not in os.environ:
        warnings.warn(ESMPY_IMPORT_MESSAGE)
    if os.environ.get("ESMFMKFILE") == "":
        warnings.warn(ESMPY_IMPORT_MESSAGE)
    esmpy = None
    msg = " {}\nProceeding without capabilities provided by ESMPy (esmf)."
    warnings.warn(msg.format(str(_ESMPY_IMPORT_ERROR)))


def _source_cube_sanity_check(src_cube):
    sx = src_cube.coord(axis="x")
    sy = src_cube.coord(axis="y")
    sgrid_dims = set(src_cube.coord_dims(sx) + src_cube.coord_dims(sy))

    coords_sharing = [
        src_cube.coords(contains_dimension=sgrid_dim) for sgrid_dim in sgrid_dims
    ]
    for coords in coords_sharing:
        for coord in coords:
            if coord not in [sx, sy]:
                msg = "Additional coordinate(s) vary along the " "horizontal mapping."
                raise ValueError(msg)


def _supported_cube_check(cube):
    x_dims = cube.coord_dims(cube.coord(axis="x"))
    y_dims = cube.coord_dims(cube.coord(axis="y"))
    msg = "Currently only increasing rank dimension mappings are supported."
    if len(x_dims) == len(y_dims):
        if len(x_dims) == 1:
            if not np.allclose(
                (rankdata(y_dims + x_dims, method="ordinal") - 1), np.arange(2)
            ):
                msg += "  For 1D cases, expecting y, x dimension mapping " "(not x, y)."
                raise RuntimeError(msg)
        else:
            for dims in [x_dims, y_dims]:
                if not np.allclose(
                    (rankdata(dims, method="ordinal") - 1), np.arange(len(dims))
                ):
                    raise RuntimeError(msg)


def _remove_undesirable_attributes(cube):
    # Temporary workaround for the valid_x attributes which should likely not
    # persist a regrid operation.
    rm_attributes = ["valid_range", "valid_min", "valid_max"]
    [
        cube.attributes.pop(key) if key in rm_attributes else None
        for key in list(cube.attributes.keys())
    ]


class _LatLonExtractor(object):
    """
    Adaptor class that takes a cube and extracts the true cell latitudes and
    longitudes.

    """

    def __init__(self, cube, staggering="corner"):
        """
        Constructor

        Parameter
        ----------
        cube: :class:`~iris.cube.Cube`

        staggering: :str
            Either 'corner' or ''.
        """
        self.lats = None
        self.lons = None
        self.inlon_coord = cube.coord(axis="x")
        self.inlat_coord = cube.coord(axis="y")

        is_bounded = self.inlon_coord.has_bounds() and self.inlat_coord.has_bounds()
        if staggering == "corner" and not is_bounded:
            msg = "Must provide bounds for corner staggering"
            raise ValueError(msg)

        # Ensure we have a suitable coordinate system present.
        ants.utils.cube.set_crs(cube)

        from_crs = self.inlon_coord.coord_system.as_cartopy_crs()
        # Transform the rotated pole grid to true lat/lon.
        to_crs = ccrs.Geodetic()
        yy, xx = self._get_2d_latlon(staggering)
        xyz = to_crs.transform_points(from_crs, xx, yy)
        self.lats = xyz[..., 1]
        self.lons = xyz[..., 0]

    def get_latitude(self):
        """
        Get the true latitudes.

        Returns
        -------
        : class:`~numpy.array`
           The latitudes.
        """
        return self.lats

    def get_longitude(self):
        """
        Get the true longitudes.

        Returns
        -------
        : class:`~numpy.array`
           The longitudes.
        """
        return self.lons

    def _get_2d_latlon(self, staggering):
        ndims = len(self.inlat_coord.shape)

        if ndims == 1:
            if staggering == "corner":
                lat1d = self._get_1d_corner(self.inlat_coord)
                lon1d = self._get_1d_corner(self.inlon_coord)
            else:
                lat1d = self.inlat_coord.points
                lon1d = self.inlon_coord.points
            lons, lats = np.meshgrid(lon1d, lat1d)

        elif ndims == 2:
            if staggering == "corner":
                latbnd = self.inlat_coord.bounds[0, 0]
                lonbnd = self.inlon_coord.bounds[0, 0]
                clockwise = not LinearRing(
                    [(lon, lat) for lon, lat in zip(lonbnd, latbnd)]
                ).is_ccw
                lats = self._get_2d_corner(self.inlat_coord, clockwise)
                lons = self._get_2d_corner(self.inlon_coord, clockwise)
            else:
                lats = self.inlat_coord.points
                lons = self.inlon_coord.points
        else:
            msg = "Only 1d and 2d horizontal coordinates are supported."
            raise ValueError(msg)

        return lats, lons

    def _get_1d_corner(self, coord):
        n = coord.shape[0]
        x = np.zeros((n + 1,), np.float64)
        x[:-1] = coord.bounds[:, 0]
        x[-1] = coord.bounds[-1, 1]
        return x

    def _get_2d_corner(self, coord, clockwise):
        # ESMPy expects unique nodes, while bounds have many duplicates
        # (adjacent cells have 2 common bounds for contiguous data).  This
        # method slices the bounds arrays to eliminate that duplication and to
        # yield a nodes array with every node included once.
        test_bounds = coord.bounds[0, :2]
        allclose = ants.utils.ndarray.allclose

        # test_bounds is the first two cells:
        #
        # X ---- B ---- X
        # |      |      |
        # |  1   |   2  |
        # |      |      |
        # X ---- A ---- X
        #
        # We then test where the bounds array starts by comparing which
        # starting position ends up with the bounds A and B being shared
        # between the two cells.
        bottom_left_start = allclose(
            test_bounds[0, 1:3], np.array([test_bounds[1, 0], test_bounds[1, 3]])
        )
        bottom_right_start = allclose(test_bounds[0, 0:2], test_bounds[1, 2:])
        top_right_start = allclose(
            np.array([test_bounds[0, 0], test_bounds[0, 3]]), test_bounds[1, 1:3]
        )
        # Default to top left start for bounds
        topleft = 0
        if bottom_left_start:
            topleft = 3
        elif bottom_right_start:
            topleft = 2
        elif top_right_start:
            topleft = 1

        bounds = coord.bounds
        if clockwise:
            # Adjust indexing where clockwise.
            topleft += 1
        clockwise = -1 if clockwise else 1

        m, n = coord.shape[0], coord.shape[-1]
        xx = np.zeros((m + 1, n + 1), np.float64)
        xx[:-1, :-1] = bounds[:, :, topleft % 4]
        xx[-1, :-1] = bounds[-1, :, (topleft + (clockwise * 1)) % 4]
        xx[-1, -1] = bounds[-1, -1, (topleft + (clockwise * 2)) % 4]
        xx[:-1, -1] = bounds[:, -1, (topleft + (clockwise * 3)) % 4]

        return xx


class _BoxIterator:
    """
    Box iterator is a class that allows one to iterate over the cells of
    boxes in any number of dimensions.

    """

    def __init__(self, dims, row_major=True):
        """
        Constructor

        Parameters
        ----------
        dims: : list of dimensions along each axis
        row_major: :True if row major, False if column major

        """
        self.dims = dims
        self.ntot = functools.reduce(operator.mul, self.dims, 1)
        self.ndims = len(self.dims)
        self.big_index = -1
        self.dim_prod = np.array([1 for i in range(self.ndims)])
        if row_major:
            # row major
            for i in range(self.ndims - 2, -1, -1):
                self.dim_prod[i] = self.dim_prod[i + 1] * self.dims[i + 1]
        else:
            # column major
            for i in range(1, self.ndims):
                self.dim_prod[i] = self.dim_prod[i - 1] * self.dims[i - 1]

    def __iter__(self):
        return self

    def __next__(self):
        if self.big_index < self.ntot - 1:
            self.big_index += 1
            return self
        else:
            raise StopIteration

    def get_indices(self):
        """
        Return
        ------
         current index set
        """
        return self.get_indices_from_big_index(self.big_index)

    def get_big_index(self):
        """
        Return
        ------
         current big index
        """
        return self.big_index

    def get_indices_from_big_index(self, big_index):
        """
        Get index set from given big index.

        Parameters
        ----------
        big_index: : big index

        Return
        ------
          index set

        Note
        ----
        no checks are performed to ensure that the returned big index is valid.

        """
        indices = np.array([0 for i in range(self.ndims)])
        for i in range(self.ndims):
            indices[i] = big_index // self.dim_prod[i] % self.dims[i]
        return indices

    def get_big_index_from_indices(self, indices):
        """
        Get the big index from a given set of indices.

        Parameters
        ----------
        indices: : index set

        Return
        ------
         big index

        Note
        ----
          no checks are performed to ensure that the returned indices are valid
        """
        return functools.reduce(
            operator.add, [self.dim_prod[i] * indices[i] for i in range(self.ndims)], 0
        )

    def reset(self):
        """
        Reset big index.
        """
        self.big_index = -1

    def get_dims(self):
        """
        Get the axis dimensions.

        Return
        ------
         return list of dimensions
        """
        return self.dims

    def is_big_index_valid(self, big_index):
        """
        Test if big index is valid.

        Parameters
        ---------
        big_index: : big index

        Return
        ------
        True if big index is in range, False otherwise
        """
        return big_index < self.ntot and big_index >= 0

    def are_indices_valid(self, inds):
        """
        Test if indices are valid.

        Parameters
        ----------
        @param inds index set

        Return
        ------
          True if valid, False otherwise
        """
        return functools.reduce(
            operator.and_,
            [inds[d] < self.dims[d] and inds[d] >= 0 for d in range(self.ndims)],
            True,
        )


[docs] class ESMFRegridder(object):
[docs] def __init__(self, src_cube, target_cube, **kwargs): """ Regridding using ESMF via ESMPY. Suitable for general curvilinear grids. Parameters ---------- src_cube : :class:`~iris.cube.Cube` Defining the source grid. Must have latitude and longitude coordinates. Latitude and longitude can be axes (iris.coords.DimCoord) or auxilliary coordinates (iris.coords.AuxCoord) -- lat/lon axes will be converted to iris.coords.AuxCoord if need be. The cube can have additional axes, e.g. elevation, time, etc., data will be interpolated linearly along those axes. target_cube : :class:`~iris.cube.Cube` Defining the target grid. Same conditions as for src_cube apply for the coordinates. method : :class:`str`, optional Defining the regridding method. Currently supported methods are: "areaWeighted" (default) persistent_cache : :obj:`bool`, optional Determine whether cache persists between runs. That is, whether the cache persists after the program is terminated and will be available for successive runs of the application. The cache location is determined by the TMPDIR environmental variable. Cache filenames are derived from source-target grid metadata checksums. Default is False (that is, cache is destroyed with the class). """ keywarg_diff = set(kwargs.keys()) - set(["method", "persistent_cache"]) if keywarg_diff: msg = "unexpected keyword argument {}" raise ValueError(msg.format(keywarg_diff)) if esmpy is None: raise _ESMPY_IMPORT_ERROR _supported_cube_check(src_cube) _supported_cube_check(target_cube) # Set some parameters. self.handle = None self.coordSystem = esmpy.api.constants.CoordSys.SPH_DEG self.method = esmpy.api.constants.RegridMethod.CONSERVE self.stagger = esmpy.StaggerLoc.CENTER method = kwargs.get("method", "areaweighted") if method.lower() != "areaweighted": raise ValueError("Currently only area weighted regridding " "supported.") # Simply return if the src and tgt grids are identical. if (src_cube.coord(axis="x") == target_cube.coord(axis="x")) and ( src_cube.coord(axis="y") == target_cube.coord(axis="y") ): return _source_cube_sanity_check(src_cube) # Build the 2D esmf grid and field objects. self.esmpy_src_grid, self.esmpy_src_field = self._build_field(src_cube) self.esmpy_tgt_grid, self.esmpy_tgt_field = self._build_field(target_cube) # Compute/read the weights following ESMPy weights tutorial. See # ESMPy docs for details of arguments. self._cache_fnme = self._gen_cache_filename([src_cube, target_cube]) self._persistent_cache = bool(kwargs.get("persistent_cache", False)) if not os.path.isfile(self._cache_fnme): # No existing cache so have ESMF generate it. self.handle = esmpy.api.regrid.Regrid( self.esmpy_src_field, self.esmpy_tgt_field, regrid_method=self.method, line_type=esmpy.api.constants.LineType.CART, unmapped_action=esmpy.api.constants.UnmappedAction.IGNORE, ignore_degenerate=True, filename=self._cache_fnme, ) else: # Utilise the existing cache. try: self.handle = esmpy.api.regrid.RegridFromFile( self.esmpy_src_field, self.esmpy_tgt_field, self._cache_fnme ) except ValueError as err: msg = " Problem attempting to utilise cache file {}" msg = msg.format(self._cache_fnme) err_msg = list(err.args) err_msg[0] += msg err.args = err_msg raise # Get the latitude/longitude self.tgt_latlon = self._get_latlon_from_cube(target_cube) # Store reference to target cube, sets the output grid. self.tgt_cube = target_cube # Record source grid used in the calculation of the weights. self.src_grid = [src_cube.coord(axis="x"), src_cube.coord(axis="y")]
@property def cache(self): """ Return the cache produced from the esmf regrid. Return the deferred columns, rows and weights esmf cache: where columns correspond to source cell indices; rows correspond to target cell indices and weights corresponding to the column-row mapping. All three will match in size. The weights correspond to the fraction of the target cell which is overlapped by the given source cell. See the following illustration:: |-| - Source |---------| - Target Here the weight between the source and target cell is 0.25 as the source cell covers 25% of the target cell. |---------| - Source |-| - Target Here the weight between the source and target cell is 1 as the source cell covers 100% of the target cell. See Also -------- http://earthsystemmodeling.org/docs/release/ESMF_8_3_1/ESMF_refdoc/node3.html#SECTION03029000000000000000 : for esmf weight only file specification. Note ---- ESMPy currently only creates a "Weight Only Weight File" and so doesn't contain the destination fraction (frac_b). Points are assumed not to extend beyond the grid. Returns ------- : `numpy.ndarray`, `numpy.ndarray`, `numpy.ndarray` source indices (columns), target indices (rows), weights. Examples -------- Example usage of cache in performing area weighted regrid calculation:: regridder = self.scheme.regridder(source, target_grid) columns, rows, weights = regridder.cache source_flattened = source.data.reshape(-1) result = target_grid.copy( np.zeros(target_grid.shape, dtype='float')) result_flattened = result.data.reshape(-1) for ind in range(rows.size): row = rows[ind] column = columns[ind] result_flattened[row] = ( result_flattened[row] + (weights[ind]*source_flattened[column])) Example as above but with sparse array usage:: result = target.copy(np.zeros(target.shape, dtype='float')) sparse_array = scipy.sparse.coo_matrix( (weights, (rows, columns)), shape=(np.prod([tgt.data.shape[1], tgt.data.shape[2]]), src.data.size)).tocsc() result.data.reshape(-1)[:] = (sparse_array * src.data.reshape(-1)) """ columns, rows, weights = ants.io.load.load( self._cache_fnme, ["col", "row", "S"] ) # esmf is Fortran based, which means we must convert indices to C # ordered indices. # - Convert to 0 based indexing. # - Convert from between column and row based indexing. c_columns = np.unravel_index( columns.data - 1, self.esmpy_src_field.data.shape, order="F" ) c_columns = np.ravel_multi_index( c_columns, self.esmpy_src_field.data.shape, order="C" ) c_rows = np.unravel_index( rows.data - 1, self.esmpy_tgt_field.data.shape, order="F" ) c_rows = np.ravel_multi_index( c_rows, self.esmpy_tgt_field.data.shape, order="C" ) return c_columns, c_rows, weights.data def _gen_cache_filename(self, cubes): m = hashlib.md5() m.update(str(self.method).encode("utf-8")) m.update(str(self.__class__).encode("utf-8")) for cube in cubes: for coord in [cube.coord(axis="x"), cube.coord(axis="y")]: ncoord = coord.copy() ncoord.var_name = None m.update(str(ncoord).encode("utf-8")) return os.path.join(tempfile.gettempdir(), m.hexdigest() + ".nc") def __del__(self): # Free memory for item in [ "handle", "esmpy_src_field", "esmpy_tgt_field", "esmpy_src_grid", "esmpy_tgt_grid", ]: handle = getattr(self, item, None) if handle is not None: handle.destroy() if not getattr(self, "_persistent_cache", True) and os.path.isfile( self._cache_fnme ): os.remove(self._cache_fnme)
[docs] def __call__(self, inp_cube): """ Apply the interpolation weights to the source field. Parameters ---------- inp_cube : :class:`~iris.cube.Cube` Defining the input cube which has same horizontal grid as src_cube, see constructor. Returns ------- : class:`~iris.cube.Cube` Target cube with regridded data. """ # Do not perform an unnecessary regrid when the source and target are # identical. if not hasattr(self, "tgt_cube"): return inp_cube # Populating and check coordinate system information. ants.utils.cube.set_crs(inp_cube) equal_x = ants.utils.coord.relaxed_equality( inp_cube.coord(axis="x"), self.src_grid[0] ) equal_y = ants.utils.coord.relaxed_equality( inp_cube.coord(axis="y"), self.src_grid[1] ) if not equal_x or not equal_y: msg = ( "The provided source cube has a horizontal grid which is " "not identical to that used to derive the weights." ) raise ValueError(msg) _supported_cube_check(inp_cube) _source_cube_sanity_check(inp_cube) # Start and end data indices when running in parallel. src_ib = self.esmpy_src_grid.lower_bounds[esmpy.StaggerLoc.CORNER][0] src_ie = self.esmpy_src_grid.upper_bounds[esmpy.StaggerLoc.CORNER][0] src_jb = self.esmpy_src_grid.lower_bounds[esmpy.StaggerLoc.CORNER][1] src_je = self.esmpy_src_grid.upper_bounds[esmpy.StaggerLoc.CORNER][1] tgt_ib = self.esmpy_tgt_grid.lower_bounds[esmpy.StaggerLoc.CORNER][0] tgt_ie = self.esmpy_tgt_grid.upper_bounds[esmpy.StaggerLoc.CORNER][0] tgt_jb = self.esmpy_tgt_grid.lower_bounds[esmpy.StaggerLoc.CORNER][1] tgt_je = self.esmpy_tgt_grid.upper_bounds[esmpy.StaggerLoc.CORNER][1] # Create the output cube from the target horizontal coordinates and the # input cube axes. out_cube = self._create_output_cube(inp_cube) # Check, input and output cubes must have lat/lon coords. self._check_cubes(inp_cube, out_cube) # Collect all the coordinate indices that are neither longitude nor # latitude, same for input and output cubes. other_inds = self._get_other_coord_indices(out_cube) # All the dimensions other than lat/lon, same for input and output # cubes. other_dims = [out_cube.data.shape[i] for i in other_inds] # Collect all the indices and their standard name. inp_name2i = self._get_coord_indices(inp_cube) inp_skip = tuple(set(inp_name2i["longitude"] + inp_name2i["latitude"])) out_name2i = self._get_coord_indices(out_cube) out_skip = tuple(set(out_name2i["longitude"] + out_name2i["latitude"])) # Iterate over the non lat/lon dimensions of the target field. for it in _BoxIterator(other_dims): # Set of indices, excluding the lat/lon indices. inds = it.get_indices() # Slicing operator for the input/ouput field. These have ':' in # place of the lat/lon indices and integers for all other axes. inp_inds_ext = self._get_extended_slice(inp_skip, inds, inp_cube) out_inds_ext = self._get_extended_slice(out_skip, inds, out_cube) # Get the input data for that slice. inp_data = inp_cube.data[inp_inds_ext] # Regrid the mask. out_mask = None if np.ma.is_masked(inp_cube.data): self.esmpy_src_field.data[src_ib:src_ie, src_jb:src_je] = ( inp_cube.data.mask[inp_inds_ext] ) self.handle(self.esmpy_src_field, self.esmpy_tgt_field) # CP: Not now, but in future we would define this with a # tolerance (mdtol). out_mask = self.esmpy_tgt_field.data[tgt_ib:tgt_ie, tgt_jb:tgt_je] > 0.0 # Regrid the field. inp_data = inp_cube.data[inp_inds_ext] self.esmpy_src_field.data[src_ib:src_ie, src_jb:src_je] = inp_data self.handle(self.esmpy_src_field, self.esmpy_tgt_field) # Apply mask and copy into cube container. out_cube.data[out_inds_ext] = self.esmpy_tgt_field.data[ tgt_ib:tgt_ie, tgt_jb:tgt_je ] if out_mask is not False: out_cube.data[out_inds_ext] = np.ma.masked_where( out_mask, out_cube.data[out_inds_ext] ) return out_cube
def _check_cubes(self, src_cube, tgt_cube): has_lat, has_lon = self._check_has_latitudes_and_longitudes(src_cube) if not has_lat: msg = "No latitude in source cube" raise ValueError(msg) if not has_lon: msg = "No longitude in source cube" raise ValueError(msg) has_lat, has_lon = self._check_has_latitudes_and_longitudes(tgt_cube) if not has_lat: msg = "No latitude in target cube" raise ValueError(msg) if not has_lon: msg = "No longitude in target cube" raise ValueError(msg) def _check_has_latitudes_and_longitudes(self, cube): # Must have latitudes and longitudes. has_lat, has_lon = False, False for coord in cube.coords(): if isinstance(coord.standard_name, str) or isinstance( coord.standard_name, str ): if coord.standard_name.find("latitude") >= 0: has_lat = True if coord.standard_name.find("longitude") >= 0: has_lon = True return has_lat, has_lon def _get_latlon_from_cube(self, cube): # # Extract the latitude and longitude coordinates from the cube. # Apply coordinate transformation to convert a dim coord to # a 2d aux coordinate if need be. # data = { "lat_coord": None, "lat_data": None, "lon_coord": None, "lon_data": None, } data["lon_coord"], data["lat_coord"] = ants.utils.cube.horizontal_grid(cube) # for the data, no need to use the bounds to compute the lats/lons staggering = "" extractor = _LatLonExtractor(cube, staggering=staggering) data["lat_data"] = extractor.get_latitude() data["lon_data"] = extractor.get_longitude() return data def _create_output_cube(self, inp_cube): # Create the output cube from the target cube and the input cube. # Latitudes and longitudes come from the tgt_cube passed to the # constructor. All other axes come from the input cube. # Gather the output cube's coordinates from self.tgt_cube and inp_cube inp_coords = inp_cube.coords() ndims = len(inp_cube.data.shape) out_data_shape = [None for i in range(ndims)] out_coords = [] out_coord_dims = [] offset = 0 for j in range(len(inp_coords)): coord = inp_coords[j] dims = inp_cube.coord_dims(coord) std_name = coord.standard_name is_string = isinstance(std_name, str) or isinstance(std_name, str) if is_string and std_name.find("latitude") >= 0: # Take latitude from the target cube. out_coords.append(self.tgt_latlon["lat_coord"]) if len(dims) > 1: for i in range(len(dims)): out_data_shape[dims[i]] = self.tgt_latlon["lat_data"].shape[i] else: out_data_shape[dims[0]] = self.tgt_latlon["lat_data"].shape[offset] offset += 1 elif is_string and std_name.find("longitude") >= 0: # Take longitude from the target cube. out_coords.append(self.tgt_latlon["lon_coord"]) if len(dims) > 1: for i in range(len(dims)): out_data_shape[dims[i]] = self.tgt_latlon["lon_data"].shape[i] else: out_data_shape[dims[0]] = self.tgt_latlon["lon_data"].shape[offset] offset += 1 else: # Take coordinate from the input cube. out_coords.append(coord) for i in range(len(dims)): out_data_shape[dims[i]] = coord.shape[i] out_coord_dims.append(dims) # Create the output cube with data initialised to zero. data = np.ma.zeros(out_data_shape, np.float64) out_cube = iris.cube.Cube( data, standard_name=inp_cube.standard_name, var_name=inp_cube.var_name ) # Add attributes and rename. out_cube.attributes = inp_cube.attributes.copy() out_cube.rename(inp_cube.name()) # Build 2d data dimensions for lat and lon in case input cube had 1d # dim coord and target has 2d aux coords. latlon_dims = [] inp_lon_coord, inp_lat_coord = ants.utils.cube.horizontal_grid(inp_cube) lat_dims = inp_cube.coord_dims(inp_lat_coord) lon_dims = inp_cube.coord_dims(inp_lon_coord) offset = 0 latlon_dims.append(lat_dims[offset]) if len(lat_dims) > 1: offset += 1 latlon_dims.append(lon_dims[offset]) # Add coordinates. offset = 0 for j in range(len(out_coords)): coord = out_coords[j] dims = out_coord_dims[j] if isinstance(coord, iris.coords.DimCoord): # DimCoord. out_cube.add_dim_coord(coord, data_dim=dims[offset]) if len(dims) > 1: offset += 1 elif isinstance(coord, iris.coords.AuxCoord): # AuxCoord. if hasattr(dims, "__len__") and len(dims) == len(coord.shape): out_cube.add_aux_coord(coord, data_dims=dims) else: # Input cube has a dim coord while target has 2d aux coord out_cube.add_aux_coord(coord, data_dims=latlon_dims) # copy the attribute from the input cube out_cube.metadata = inp_cube.metadata _remove_undesirable_attributes(out_cube) return out_cube def _get_extended_slice(self, skip, inds, cube): # # Build the extended slice, using inds as index set for the axes # E.g. for skip=(0, 2) and inds = [a, b, c], returns (:, a, :, b, c), # that is the slice through indices 0 and 2 and take a, b, c for the # remaining indices. # # Note: May not run in parallel! slce = slice(0, None, None) n = len(skip) + len(inds) coord_inds_ext = [] j = 0 for i in range(n): if i in skip: coord_inds_ext.append(slce) else: coord_inds_ext.append(inds[j]) j += 1 inds_ext = coord_inds_ext return tuple(inds_ext) def _build_field(self, cube): # # Build the esmpy field object. # staggering = "corner" if self.method != esmpy.api.constants.RegridMethod.CONSERVE: # Need to pass corner coordinates in all cases. When the field is # cell centred this requires using bounds n (so staggering is # corner). For all other regridding methods we will not use the # bounds (so staggering is ''). staggering = "" # Get the true latitudes and longitudes on cell vertices. extractor = _LatLonExtractor(cube, staggering) lats = extractor.get_latitude() lons = extractor.get_longitude() # Create the grid. cellDims = np.array([lons.shape[0] - 1, lats.shape[1] - 1]) grid = esmpy.Grid(max_index=cellDims, coord_sys=self.coordSystem) # Allocate space for the vertices, esmpy wants the first coordinate to # be longitudes. grid.add_coords(staggerloc=esmpy.StaggerLoc.CORNER, coord_dim=0) # No need to add lats, it will be added automatically with lons # Get pointers to the esmf coordinates. lonPoint = grid.get_coords(coord_dim=0, staggerloc=esmpy.StaggerLoc.CORNER) latPoint = grid.get_coords(coord_dim=1, staggerloc=esmpy.StaggerLoc.CORNER) # When esmpy runs in parallel, the start/end indices may be other than # 0,-1. # CP: Is esmpy running in parallel being tested?? I suggest just # removing it, and we can add it once we get something that just # works onto trunk. I think there is greater benefit to getting # this onto trunk soon and looking at another iteration of # development when the need arises. Perhaps I'm wrong?? ibeg0 = grid.lower_bounds[esmpy.StaggerLoc.CORNER][0] iend0 = grid.upper_bounds[esmpy.StaggerLoc.CORNER][0] ibeg1 = grid.lower_bounds[esmpy.StaggerLoc.CORNER][1] iend1 = grid.upper_bounds[esmpy.StaggerLoc.CORNER][1] lonPoint[...] = lons[ibeg0:iend0, ibeg1:iend1] latPoint[...] = lats[ibeg0:iend0, ibeg1:iend1] # Build the field, stagger is either CENTER or CORNER depending # on the method of interpolation. (Might consider choosing the method # given the cell_method.) dtype = esmpy.api.constants.TypeKind.R8 # always use double precision field = esmpy.Field(grid, staggerloc=self.stagger, typekind=dtype) # Note: not setting the field data at this point. return grid, field def _get_coord_indices(self, cube): # # Associates a coordinate name to an index. # res = {} coords = cube.coords() for coord in coords: std_name = coord.standard_name res[std_name] = cube.coord_dims(coord) """ for i in range(len(coords)): std_name = coords[i].standard_name res[std_name] = i """ if "grid_longitude" in res: res["longitude"] = res["grid_longitude"] if "grid_latitude" in res: res["latitude"] = res["grid_latitude"] return res def _get_other_coord_indices(self, cube): # # Collect all the indices that are neither longitude nor latitude # (whether mapped or not). res = [] coords = cube.coords() for i in range(len(coords)): coord = coords[i] std_name = coord.standard_name dims = cube.coord_dims(coord) if not ( std_name == "longitude" or std_name == "grid_longitude" or std_name == "latitude" or std_name == "grid_latitude" ): for j in dims: res.append(j) return tuple(set(res))
[docs] class ConservativeESMF(object): """ ESMF regridding scheme using esmpy. Regridding suitable for general curvilinear grids. """
[docs] def __init__(self): self._method = "areaweighted"
[docs] def regridder(self, src_grid_cube, target_grid_cube, **kwargs): """ Creates an ESMF regridding scheme using esmpy. Parameters ---------- src_grid_cube : :class:`~iris.cube.Cube` Defining the source grid. target_grid_cube : :class:`~iris.cube.Cube` Defining the target grid. Returns ------- : :class:`ESMFRegridder` Callable with the interface `callable(cube)` where `cube` is a cube with the same grid as `src_grid_cube` that is to be regridded to the `target_grid_cube`. """ return ESMFRegridder( src_grid_cube, target_grid_cube, method=self._method, **kwargs )