# (C) Crown Copyright, Met Office. All rights reserved.
#
# This file is part of UG-ANTS and is released under the BSD 3-Clause license.
# See LICENSE.txt in the root of the repository for full licensing details.
"""Implementation for the regrid application."""
import os
from pathlib import Path
from typing import Literal
import esmf_regrid
import numpy as np
from esmf_regrid.experimental.io import load_regridder, save_regridder
from esmf_regrid.experimental.unstructured_scheme import GridToMeshESMFRegridder
from iris.cube import Cube, CubeList
from iris.experimental.ugrid import Mesh, save_mesh
from iris.fileformats.netcdf import save as save_netcdf
from ugants.abc import Application
from ugants.io import save
from ugants.io.load import cf as load_cf
from ugants.io.load import is_netcdf
from ugants.io.save import (
_check_filepath_extension as verify_output_to_netcdf,
)
from ugants.io.save import (
ugrid as save_ugrid,
)
from ugants.regrid import band_utils
from ugants.utils.cube import as_cubelist
[docs]
class Regrid(Application):
"""Regrid regular grid data to an unstructured mesh.
Parameters
----------
source : :class:`iris.cube.CubeList`
The regular gridded source data to be regridded.
target_mesh : :class:`iris.experimental.ugrid.Mesh`
The UGrid target mesh.
horizontal_regrid_scheme : :obj:`str`
The horizontal regrid scheme to be used. Supported schemes are
"conservative", "bilinear", "nearest".
tolerance : :obj:`float`
Tolerance of missing data.
The value returned in each element of the returned array will be masked if the
fraction of masked data exceeds this tolerance.
If not provided, the default tolerance value is zero.
This option is not available for the "nearest" scheme.
input_weights
Path to cached input weights.
An optional file containing the pre-generated weights for the mesh
being used for regridding.
output_weights
An optional target path for output cached weights. Using pre-cached weights
makes the regridding process less computationally expensive.
Raises
------
ValueError
If a ``tolerance`` is provided and the ``horizontal_regrid_scheme``
is "nearest".
Note
----
The data is always regridded to the faces of the unstructured grid.
"""
results: CubeList = None
"""The source data regridded onto the faces of the target mesh."""
source: CubeList = None
"""The source data to be regridded."""
target_mesh: Mesh = None
"""The mesh to regrid to."""
_loader = load_cf
def __init__(
self,
source: CubeList,
target_mesh: Mesh,
horizontal_regrid_scheme: Literal["conservative", "bilinear", "nearest"],
tolerance: float = 0,
input_weights: str = "",
output_weights: str = "",
):
if tolerance and horizontal_regrid_scheme == "nearest":
raise ValueError(
"The 'tolerance' option is not available for regrid scheme 'nearest'"
)
source = as_cubelist(source)
if len(source) > 1:
_validate_source_grids(source)
if input_weights != "" and output_weights != "":
raise ValueError(
"Only one of input_weights and output_weights can be provided"
)
self.source = source
self.target_mesh = target_mesh
self.horizontal_regrid_scheme = horizontal_regrid_scheme
self.tolerance = tolerance
self.input_weights = input_weights
self.output_weights = output_weights
[docs]
def run(self):
"""Regrid :attr:`source` to :attr:`target_mesh`.
The result of the regrid is stored in :attr:`results`.
"""
source_cube = self.source[0]
if self.input_weights:
is_netcdf(self.input_weights)
self.regridder = load_regridder(Path(self.input_weights))
_validate_input_weights(
self.regridder,
self.tolerance,
self.horizontal_regrid_scheme,
)
else:
self.regridder = GridToMeshESMFRegridder(
source_cube,
self.target_mesh,
method=self.horizontal_regrid_scheme,
tgt_location="face",
mdtol=self.tolerance,
)
self.results = CubeList(self.regridder(cube) for cube in self.source)
[docs]
def save(self):
"""Save ``self.results`` to ``self.output``."""
verify_output_to_netcdf(Path(self.output))
save_ugrid(self.results, self.output)
if self.output_weights:
save_regridder(self.regridder, str(self.output_weights))
def _validate_input_weights(
regridder: GridToMeshESMFRegridder,
tolerance: float,
horizontal_regrid_scheme: Literal["conservative", "bilinear", "nearest"],
):
if regridder.mdtol != tolerance:
raise ValueError(
"Tolerance value of input_weights does not match the value"
" provided on command line."
)
if regridder.method != horizontal_regrid_scheme:
raise ValueError(
"Regrid scheme of input_weights does not match the scheme "
"provided on command line."
)
[docs]
class RegridMeshToMesh(Application):
"""Regrid unstructured mesh data to an unstructured mesh of a different resolution.
Parameters
----------
source
The unstructured mesh data to be regridded.
target_mesh
The UGrid target resolution mesh.
horizontal_regrid_scheme
The horizontal regrid scheme to be used. Supported schemes are
"conservative", "bilinear", "nearest".
tolerance
Tolerance of missing data.
The value returned in each element of the returned array will be masked if the
fraction of masked data exceeds this tolerance.
If not provided, the default tolerance value is zero.
This option is not available for the "nearest" scheme.
Raises
------
ValueError
If a ``tolerance`` is provided and the ``horizontal_regrid_scheme``
is "nearest".
Note
----
The data is always regridded to the faces of the unstructured mesh.
"""
results: CubeList = None
"""The source data regridded onto the faces of the target mesh."""
def __init__(
self,
source: CubeList,
target_mesh: Mesh,
horizontal_regrid_scheme: Literal["conservative", "bilinear", "nearest"],
tolerance: float = 0,
):
if tolerance and horizontal_regrid_scheme == "nearest":
raise ValueError(
"The 'tolerance' option is not available for regrid scheme 'nearest'"
)
source = as_cubelist(source)
if len(source) > 1:
_validate_source_meshes(source)
self.source = source
self.target_mesh = target_mesh
self.horizontal_regrid_scheme = horizontal_regrid_scheme
self.tolerance = tolerance
[docs]
def run(self):
"""Regrid ``self.source`` to ``self.target_mesh``."""
regridder_lookup = {
"conservative": esmf_regrid.ESMFAreaWeightedRegridder,
"bilinear": esmf_regrid.ESMFBilinearRegridder,
"nearest": esmf_regrid.ESMFNearestRegridder,
}
regridder_kwargs = {
"src": self.source[0],
"tgt": self.target_mesh,
"tgt_location": "face",
}
# The nearest neighbour scheme does not accept a mdtol argument
# but other schemes do
if self.horizontal_regrid_scheme != "nearest":
regridder_kwargs["mdtol"] = self.tolerance
# Select the appropriate regridder from the lookup, and instantiate with
# appropriate keyword arguments. This regridder can then be reused for all cubes
regridder = regridder_lookup[self.horizontal_regrid_scheme](**regridder_kwargs)
self.results = CubeList(regridder(source) for source in self.source)
[docs]
class SplitGridToMeshByLatitude(Application):
"""Split the provided regular gridded source and target mesh into latitude bands.
Parameters
----------
source : :class:`~iris.cube.CubeList`
The global, regular gridded source data to be split.
target_mesh : :class:`~iris.experimental.ugrid.Mesh`
The target mesh to be split.
number_of_bands : int
Number of latitude bands to split by.
"""
_loader = load_cf
results: list[CubeList] = None
"""A list of CubeLists, one for each latitude band.
Each CubeList contains the same number of cubes as the source CubeList.
The domain of each cube completely covers the corresponding mesh band in
:attr:`mesh_bands`. Padding is added ensure that the mesh band is
fully enclosed, so there will be overlap between adjacent source bands.
See :func:`~ugants.regrid.band_utils.cube_subset_latitude_bounds` for more details
on how the padding is calculated.
An attribute ``band_number`` is added to distinguish each cube."""
mesh_bands: list[Mesh]
"""Each mesh covers a latitude band, approximately evenly spaced.
There are no overlapping cells between any two mesh bands, so each mesh
band covers a unique region. Together, the mesh bands cover the entire
target mesh."""
mesh_mapping_cube: Cube
"""A UGrid cube constructed from the ``target_mesh``. The data maps each cell
to its corresponding latitude band. For example, if a cell has a value of 1,
then it belongs in latitude band 1."""
output: str = None
"""The output **directory** to which to write the
:attr:`results`, :attr:`mesh_bands` and
:attr:`mesh_mapping_cube`.
"""
def __init__(self, source: CubeList, target_mesh: Mesh, number_of_bands: int):
source = as_cubelist(source)
if len(source) > 1:
_validate_source_grids(source)
_validate_source_is_global(source[0])
_validate_number_of_bands(number_of_bands)
self.source = source
self.target_mesh = target_mesh
self.number_of_bands = number_of_bands
[docs]
def run(self):
"""Run the application.
The source and target are split into bands of approximately equal latitude.
There is **no** overlap between target bands, i.e. a cell in the original target
mesh will appear in one target band only. The mapping between target
mesh cells and band number is recorded in the :attr:`mesh_mapping_cube`.
There **is** overlap between source bands, i.e. a cell in the original source
may appear in more than one source band. This is because the source domain must
extend beyond the target domain in order to capture all the required data for
regridding.
The following attributes are set by this method:
* :attr:`results`
* :attr:`mesh_bands`
* :attr:`mesh_mapping_cube`
"""
source_cube = self.source[0]
target_min_latitude = min(self.target_mesh.node_coords.node_y.points)
target_max_latitude = max(self.target_mesh.node_coords.node_y.points)
band_bounds = band_utils.generate_band_bounds(
start=target_min_latitude,
stop=target_max_latitude,
n_bands=self.number_of_bands,
)
self.mesh_mapping_cube = band_utils.mesh_to_cube(self.target_mesh)
# A boolean array of shape (n_bands, n_faces)
mesh_indices_per_band = np.array(
[
band_utils.find_cell_centres_within_latitude_bounds(
self.mesh_mapping_cube, bounds
)
for bounds in band_bounds
],
)
# An integer array of shape (n_faces,) labelling each face according to its band
self.mesh_mapping_cube.data = mesh_indices_per_band.argmax(0)
self.mesh_bands = [
band_utils.subset_mesh_cube_by_indices(self.mesh_mapping_cube, indices).mesh
for indices in mesh_indices_per_band
]
# A list of band bounds tuples (lower_latitude, upper_latitude),
# of length n_bands
latitude_bounds_for_source = [
band_utils.cube_subset_latitude_bounds(
subsetted_mesh, source_cube.coord("latitude")
)
for subsetted_mesh in self.mesh_bands
]
self.results = []
# 1. iterate over latitude bands
for band_number, band_bounds in enumerate(latitude_bounds_for_source):
latitude_band_cubelist = CubeList()
# 2. iterate over source cubes, extract latitude band from each source cube
for source_cube in self.source:
latitude_band_cube = band_utils.constrain_source_cube_latitude(
source_cube, band_bounds
)
latitude_band_cube.attributes["band_number"] = band_number
latitude_band_cubelist.append(latitude_band_cube)
self.results.append(latitude_band_cubelist)
[docs]
def save(self):
"""Save the latitude bands to NetCDF.
Three types of file are saved to the directory specified by ``self.output``:
* :attr:`results`: ``number_of_bands`` such files are output, named
``source_band_{band_number}.nc``.
* :attr:`mesh_bands`: ``number_of_bands`` such files are output, named
``mesh_band_{band_number}.nc``.
* :attr:`mesh_mapping_cube`: only one such file is output, named
``mesh_band_mapping.nc``.
In total, :code:`2*number_of_bands + 1` files are output.
"""
if self.output is None:
raise ValueError("No output directory location has been set.")
if not hasattr(self, "mesh_mapping_cube"):
raise ValueError(
"The application has not yet been run, mesh_mapping_cube is not set."
)
if not hasattr(self, "mesh_bands"):
raise ValueError(
"The application has not yet been run, mesh_bands is not set."
)
if self.results is None:
raise ValueError(
"The application has not yet been run, results is not set."
)
save.ugrid(
self.mesh_mapping_cube, os.path.join(self.output, "mesh_band_mapping.nc")
)
for band_number, result, mesh_band in zip(
range(len(self.results)),
self.results,
self.mesh_bands,
strict=True,
):
save_mesh(
mesh_band, os.path.join(self.output, f"mesh_band_{band_number}.nc")
)
save_netcdf(
result,
os.path.join(self.output, f"source_band_{band_number}.nc"),
)
def _validate_source_grids(source: CubeList):
"""Check that all cubes have the same horizontal grid.
Does not check dimension ordering, only that the horizontal coordinates are equal.
Parameters
----------
source: iris.cube.CubeList
Cubes to be compared. Must be at least 2 cubes.
Raises
------
ValueError
If any pair of cubes have different x or y coordinates.
"""
reference_grid_x = source[0].coord(axis="x")
reference_grid_y = source[0].coord(axis="y")
for cube in source[1:]:
if (cube.coord(axis="x") != reference_grid_x) or (
cube.coord(axis="y") != reference_grid_y
):
raise ValueError("Not all source cubes have the same horizontal grid.")
def _validate_source_meshes(source: CubeList):
"""Check that all cubes have the same horizontal mesh.
Parameters
----------
source: iris.cube.CubeList
Cubes to be compared. Must be at least 2 cubes.
Raises
------
ValueError
If any pair of cubes have different meshes.
"""
for cube in source[1:]:
if cube.mesh != source[0].mesh:
raise ValueError("Not all source cubes have the same horizontal mesh.")
def _validate_source_cubelist_length(source_cubelist: CubeList):
"""Check that the source cubelist contains only one cube.
Parameters
----------
source_cubelist : CubeList
The source cubelist to validate
Raises
------
ValueError
If there is not only one cube in the source cubelist.
"""
number_source_cubes = len(source_cubelist)
if number_source_cubes != 1:
raise ValueError(f"Source contained {number_source_cubes} cubes, expected 1.")
def _validate_source_is_global(source_cube: Cube):
"""Check that the source cube is global.
The following checks are performed on the cube's horizontal coordinates:
* Longitude (axis="x") must be circular.
* Latitude (axis="y") bounds must extend from -90 to +90 degrees.
Parameters
----------
source_cube : Cube
The source cube to validate
Raises
------
ValueError
If the source data is not global
"""
longitude = source_cube.coord(axis="x").copy()
if not longitude.circular:
raise ValueError(
"The provided source is not global: longitude is not circular."
)
latitude = source_cube.coord(axis="y").copy()
if not latitude.has_bounds():
latitude.guess_bounds()
min_lat = latitude.bounds.min()
max_lat = latitude.bounds.max()
if (min_lat, max_lat) != (-90.0, 90.0):
raise ValueError(
f"The provided source is not global: latitude min = {min_lat}, "
f"latitude max = {max_lat}"
)
def _validate_number_of_bands(number_of_bands: int):
"""Check that the number of bands is greater than one.
Parameters
----------
number_of_bands : int
The number of bands to attempt to split by.
Raises
------
ValueError
If the number of bands is not greater than one.
"""
if number_of_bands < 2:
raise ValueError(
f"The number of bands must be greater than 1, got {number_of_bands}."
)
[docs]
class RecombineMeshBands(Application):
"""Recombine regridded latitude bands into a single cube.
Parameters
----------
mesh_mapping : iris.cube.CubeList
A single-element CubList which maps individual latitude bands to cells
in the target mesh. See also
:attr:`SplitGridToMeshByLatitude.mesh_mapping_cube`.
bands : iris.cube.CubeList
The regridded latitude bands to recombine into a single UGrid cube.
Each band cube must have a ``band_number`` attribute, which maps the
band to its location on the ``mesh_mapping``.
"""
results: CubeList = None
"""The recombined bands in a single cube. The cube's mesh is identical
to that of the mesh mapping, and the data is taken from the cubes in ``bands``."""
def __init__(self, mesh_mapping: CubeList, bands: CubeList):
self.mesh_mapping = as_cubelist(mesh_mapping)
self.bands = bands
self.names = {cube.name() for cube in self.bands}
# Validate that the mesh mapping has a single unstructured dimension
if self.mesh_mapping[0].mesh is None:
raise ValueError("The provided mesh_mapping does not contain a mesh.")
if (ndim := self.mesh_mapping[0].ndim) != 1:
raise ValueError(
f"The provided mesh_mapping should have 1 dimension, got {ndim}."
)
# Validate that the expected numbers of bands are provided for each variable
# Validate consistency in unstructured mesh dimension length
expected_max_band_number = int(self.mesh_mapping[0].data.max())
expected_band_numbers = list(range(expected_max_band_number + 1))
self._target_mesh_dim_length = self.mesh_mapping[0].shape[0]
for name in self.names:
band_numbers_by_name = sorted(
cube.attributes["band_number"] for cube in self.bands.extract(name)
)
if band_numbers_by_name != expected_band_numbers:
raise ValueError(
f"Inconsistent mesh bands provided for {name}: expected "
f"{expected_band_numbers}, got {band_numbers_by_name}"
)
# Sum of the number of points in the mesh dimension of each cube
bands_mesh_dim_lengths = [
cube.shape[cube.mesh_dim()] for cube in self.bands.extract(name)
]
bands_mesh_total_dim_length = sum(bands_mesh_dim_lengths)
if bands_mesh_total_dim_length != self._target_mesh_dim_length:
raise ValueError(
f"Inconsistent unstructured dimension lengths for {name} bands. "
"Provided bands have unstructured dimensions of lengths "
f"{bands_mesh_dim_lengths} giving a total of "
f"{bands_mesh_total_dim_length}, whereas mesh mapping has length "
f"{self._target_mesh_dim_length}."
)
[docs]
def run(self):
"""Recombine the latitude bands into a single cube.
The data from each latitude band cube in the ``bands``
CubeList are used to fill the corresponding cells in the target mesh,
according to the band cube's ``band_number`` attribute.
The ``mesh_mapping_cube`` provides the target mesh to be filled with data from
the regridded latitude bands. The data in ``mesh_mapping_cube`` describe which
latitude band is to be used to fill the cells.
For example, cells with value 2 in the ``mesh_mapping_cube``
are filled with data from band number 2.
This method sets the :attr:`results` attribute.
"""
self.results = CubeList(
self._recombine_single_variable(
self.bands.extract(name), self.mesh_mapping[0]
)
for name in self.names
)
def _recombine_single_variable(self, regrid_bands: CubeList, mesh_mapping: Cube):
reference_band = regrid_bands[0]
target_shape = list(reference_band.shape)
mesh_dim = reference_band.mesh_dim()
target_shape[mesh_dim] = self._target_mesh_dim_length
target_array = np.ma.masked_all(target_shape, dtype=reference_band.dtype)
for regrid_band in regrid_bands:
band_number = regrid_band.attributes["band_number"]
indices_in_band = np.nonzero(mesh_mapping.data == band_number)[0]
slice_to_fill = [slice(None)] * target_array.ndim
slice_to_fill[mesh_dim] = indices_in_band
target_array[tuple(slice_to_fill)] = regrid_band.data
dim_coords_and_dims = [
(coord, reference_band.coord_dims(coord))
for coord in reference_band.dim_coords
]
aux_coords_and_dims = [
(coord, reference_band.mesh_dim()) for coord in mesh_mapping.aux_coords
]
result = Cube(
target_array,
dim_coords_and_dims=dim_coords_and_dims,
aux_coords_and_dims=aux_coords_and_dims,
)
result.metadata = reference_band.metadata
result.attributes.pop("band_number")
return result