# This code is provided with the hope that it will be useful.
# No guarantee is provided whatsoever. Use at your own risk.
#
# David Kindig and Alex Pletzer, Tech-X Corp. (2012)
"""
ESMF regridding class
"""
import os
import re
import socket
import numpy
import ESMF
from . import esmf
from . import RegridError
from .mvGenericRegrid import GenericRegrid
# If the local host name cannot be resolved, tell MPICH to use 'localhost'
# instead (presumably so that MPI start-up inside ESMF does not fail on
# machines without a resolvable host name -- TODO confirm).
try:
    socket.gethostbyname(socket.gethostname())
except Exception:
    os.environ['MPICH_INTERFACE_HOSTNAME'] = 'localhost'

# Create the ESMF manager once, at import time, with debug output disabled.
ESMF.Manager(debug=False)

# mpi4py is optional: HAVE_MPI records whether the import succeeded.
# Only ordinary errors are swallowed, so that KeyboardInterrupt and
# SystemExit still propagate.
HAVE_MPI = False
try:
    from mpi4py import MPI
    HAVE_MPI = True
except Exception:
    pass
# constants
# Short module-level aliases for the ESMF enumeration members used below.
# Stagger locations (ESMF.StaggerLoc)
CENTER = ESMF.StaggerLoc.CENTER # Same as ESMF_STAGGERLOC_CENTER_VCENTER
CORNER = ESMF.StaggerLoc.CORNER
VCORNER = ESMF.StaggerLoc.CORNER_VFACE
# VFACE is a second name for the same stagger location as VCORNER
VFACE = VCORNER
# Regridding methods (ESMF.RegridMethod)
CONSERVE = ESMF.RegridMethod.CONSERVE
PATCH = ESMF.RegridMethod.PATCH
BILINEAR = ESMF.RegridMethod.BILINEAR
# Unmapped-point actions (ESMF.UnmappedAction)
IGNORE = ESMF.UnmappedAction.IGNORE
ERROR = ESMF.UnmappedAction.ERROR
class ESMFRegrid(GenericRegrid):
    """
    Regrid class for ESMF.

    Holds a source and a destination ESMF structured grid together with the
    fields defined on them, and applies an ESMF regridding operator between
    the two.

    Typical call order: constructor, setCoords, computeWeights, apply.

    Parameters
    ----------
    srcGridshape
        tuple source grid shape
    dstGridshape
        tuple destination grid shape
    dtype
        a valid numpy data type for the src/dst data
    regridMethod
        'linear', 'conserve', or 'patch' (anything that is not recognised
        falls back to 'linear')
    staggerLoc
        the staggering of the field: 'center', 'corner' (or 'node'),
        'vface' (or 'vcorner'); anything else falls back to 'center'
    periodicity
        0 (no periodicity),
        1 (last coordinate is periodic),
        2 (both coordinates are periodic)
    coordSys
        'deg', 'cart', or 'rad' (anything else falls back to 'deg')
    srcGridMask
        used as the source mask values for non-conservative regridding
    hasSrcBounds
        forwarded to the source grid as ``hasBounds``
    srcGridAreas
        stored on the instance (provided by user or None)
    dstGridMask
        used as the destination mask values for non-conservative regridding
    hasDstBounds
        forwarded to the destination grid as ``hasBounds``
    dstGridAreas
        stored on the instance (provided by user or None)
    ignoreDegenerate
        Ignore degenerate cells when checking inputs. The value is stored
        on the instance; note that computeWeights currently always asks
        ESMF to ignore degenerate cells.
    args
        optional 'unmappedaction' entry: 'ignore' (default) or 'error'

    Raises
    ------
    RegridError
        if the source and destination shapes have different lengths
    """

    def __init__(self, srcGridshape, dstGridshape, dtype,
                 regridMethod, staggerLoc, periodicity, coordSys,
                 srcGridMask=None, hasSrcBounds=False, srcGridAreas=None,
                 dstGridMask=None, hasDstBounds=False, dstGridAreas=None,
                 ignoreDegenerate=False,
                 **args):
        """
        Build the (still empty) grids and fields; see the class docstring
        for the meaning of the parameters.
        """
        # esmf grid objects (to be constructed below)
        self.srcGrid = None
        self.dstGrid = None
        self.dtype = dtype
        self.srcGridShape = srcGridshape
        self.dstGridShape = dstGridshape
        self.ignoreDegenerate = ignoreDegenerate
        self.ndims = len(self.srcGridShape)
        self.hasSrcBounds = hasSrcBounds
        self.hasDstBounds = hasDstBounds

        # regridding method, bilinear unless told otherwise
        self.regridMethod = BILINEAR
        self.regridMethodStr = 'linear'
        if isinstance(regridMethod, str):
            method = regridMethod.lower()
            if 'conserv' in method:
                self.regridMethod = CONSERVE
                self.regridMethodStr = 'conserve'
            elif 'patch' in method:
                self.regridMethod = PATCH
                self.regridMethodStr = 'patch'

        # data stagger, cell centered unless told otherwise
        self.staggerloc = CENTER
        self.staggerlocStr = 'center'
        if isinstance(staggerLoc, str):
            stagger = staggerLoc.lower()
            # 'vcorner' must be tested before 'corner': it contains
            # 'corner' as a substring but names the VFACE stagger (it is
            # also the value stored in staggerlocStr for that stagger).
            if 'vface' in stagger or 'vcorner' in stagger:
                self.staggerloc = VFACE
                self.staggerlocStr = 'vcorner'
            elif 'corner' in stagger or 'node' in stagger:
                self.staggerloc = CORNER
                self.staggerlocStr = 'corner'
            # there are other staggers we could test here
            # good for now

        # what to do with destination points that cannot be mapped
        unMappedAction = args.get('unmappedaction', 'ignore')
        self.unMappedAction = IGNORE
        if re.search('error', unMappedAction.lower()):
            self.unMappedAction = ERROR

        # coordinate system, spherical degrees unless told otherwise
        self.coordSys = ESMF.CoordSys.SPH_DEG
        self.coordSysStr = 'deg'
        if re.search('cart', coordSys.lower()):
            self.coordSys = ESMF.CoordSys.CART
            self.coordSysStr = 'cart'
        elif re.search('rad', coordSys.lower()):
            self.coordSys = ESMF.CoordSys.SPH_RAD
            self.coordSysStr = 'rad'

        self.periodicity = periodicity

        # masks can take several values in ESMF, we'll have just one
        # value (1) which means invalid. Keyed on the parsed method so that
        # the attributes always exist, even when regridMethod is not a
        # string (computeWeights reads them unconditionally).
        if self.regridMethod == CONSERVE:
            self.srcMaskValues = numpy.array([1], dtype=numpy.int32)
            self.dstMaskValues = numpy.array([1], dtype=numpy.int32)
        else:
            self.srcMaskValues = srcGridMask
            self.dstMaskValues = dstGridMask

        # provided by user or None
        self.srcGridAreas = srcGridAreas
        self.dstGridAreas = dstGridAreas
        self.maskPtr = None

        # MPI stuff: single process unless mpi4py could be imported
        self.pe = 0
        self.nprocs = 1
        self.comm = None
        if HAVE_MPI:
            self.comm = MPI.COMM_WORLD
            self.pe = self.comm.Get_rank()
            self.nprocs = self.comm.Get_size()

        # checks
        if self.ndims != len(self.dstGridShape):
            msg = ("\nmvESMFRegrid.ESMFRegrid.__init__: mismatch in the "
                   "number of topological\n"
                   "dimensions. len(srcGridshape) = %d != "
                   "len(dstGridshape) = %d"
                   % (self.ndims, len(self.dstGridShape)))
            raise RegridError(msg)

        # Initialize the grids without data.
        self.srcGrid = esmf.EsmfStructGrid(self.srcGridShape,
                                           coordSys=self.coordSys,
                                           periodicity=self.periodicity,
                                           staggerloc=self.staggerloc,
                                           hasBounds=self.hasSrcBounds)
        self.dstGrid = esmf.EsmfStructGrid(dstGridshape,
                                           coordSys=self.coordSys,
                                           periodicity=self.periodicity,
                                           staggerloc=self.staggerloc,
                                           hasBounds=self.hasDstBounds)

        # Initialize the fields with data.
        self.srcFld = esmf.EsmfStructField(self.srcGrid, 'srcFld',
                                           datatype=self.dtype,
                                           staggerloc=self.staggerloc)
        self.dstFld = esmf.EsmfStructField(self.dstGrid, 'dstFld',
                                           datatype=self.dtype,
                                           staggerloc=self.staggerloc)
        self.srcAreaField = esmf.EsmfStructField(self.srcGrid,
                                                 name='srcAreas',
                                                 datatype=self.dtype,
                                                 staggerloc=self.staggerloc)
        self.dstAreaField = esmf.EsmfStructField(self.dstGrid,
                                                 name='dstAreas',
                                                 datatype=self.dtype,
                                                 staggerloc=self.staggerloc)
        self.srcFracField = esmf.EsmfStructField(self.srcGrid,
                                                 name='srcFracAreas',
                                                 datatype=self.dtype,
                                                 staggerloc=self.staggerloc)
        self.dstFracField = esmf.EsmfStructField(self.dstGrid,
                                                 name='dstFracAreas',
                                                 datatype=self.dtype,
                                                 staggerloc=self.staggerloc)

        # start-up values: -1 for the data, zero areas, unit fractions
        self.srcFld.field.data[:] = -1
        self.dstFld.field.data[:] = -1
        self.srcAreaField.field.data[:] = 0.0
        self.dstAreaField.field.data[:] = 0.0
        self.srcFracField.field.data[:] = 1.0
        self.dstFracField.field.data[:] = 1.0

    def setCoords(self, srcGrid, dstGrid,
                  srcGridMask=None, srcBounds=None, srcGridAreas=None,
                  dstGridMask=None, dstBounds=None, dstGridAreas=None,
                  globalIndexing=False, **args):
        """
        Populator of grids, bounds and masks

        Parameters
        ----------
        srcGrid : list [[z], y, x] of source grid arrays
        dstGrid : list [[z], y, x] of destination grid arrays
        srcGridMask : list [[z], y, x] of arrays
        srcBounds : list [[z], y, x] of arrays
        srcGridAreas : list [[z], y, x] of arrays (accepted, not used here)
        dstGridMask : list [[z], y, x] of array
        dstBounds : list [[z], y, x] of arrays
        dstGridAreas : list [[z], y, x] of arrays (accepted, not used here)
        globalIndexing : if True array was allocated over global index space,
                         otherwise array was allocated over local index
                         space on this processor. This is only relevant if
                         rootPe is None

        Raises
        ------
        RegridError
            if bounds are supplied while the field itself lives on the
            corners (CORNER or VCORNER stagger)
        """
        # source grid: coordinates, then optional mask and bounds
        self.srcGrid.setCoords(srcGrid, staggerloc=self.staggerloc,
                               globalIndexing=globalIndexing)
        if srcGridMask is not None:
            self.srcGrid.setMask(srcGridMask, self.staggerloc)
        if srcBounds is not None:
            self._setBounds(self.srcGrid, srcBounds, 'src', globalIndexing)

        # destination grid: same sequence
        self.dstGrid.setCoords(dstGrid, staggerloc=self.staggerloc,
                               globalIndexing=globalIndexing)
        if dstGridMask is not None:
            # pass the stagger explicitly, exactly as for the source grid
            self.dstGrid.setMask(dstGridMask, self.staggerloc)
        if dstBounds is not None:
            self._setBounds(self.dstGrid, dstBounds, 'dst', globalIndexing)

    def _setBounds(self, grid, bounds, label, globalIndexing):
        """
        Attach cell bounds to grid as corner coordinates.

        Parameters
        ----------
        grid : the source or destination grid object
        bounds : list [[z], y, x] of arrays
        label : 'src' or 'dst', only used in the error message
        globalIndexing : forwarded to the grid's setCoords

        Raises
        ------
        RegridError
            if the field stagger is CORNER or VCORNER
        """
        # Coords are CENTER (cell) based, bounds are CORNER (nodal),
        # VCORNER for 3D. A nodal field has no separate bounds to set.
        if self.staggerloc == CORNER or self.staggerloc == VCORNER:
            msg = ("\nmvESMFRegrid.ESMFRegrid.setCoords: can't set the %s "
                   "bounds for\nstaggerLoc = %s!"
                   % (label, self.staggerlocStr))
            raise RegridError(msg)
        if self.ndims == 2:
            # cell field, need to provide the bounds
            grid.setCoords(bounds, staggerloc=CORNER,
                           globalIndexing=globalIndexing)
        elif self.ndims == 3:
            # cell field, need to provide the bounds
            grid.setCoords(bounds, staggerloc=VCORNER,
                           globalIndexing=globalIndexing)

    def computeWeights(self, **args):
        """
        Compute interpolation weights

        Creates the ESMF regridding operator from the source and destination
        fields and stores it in self.regridObj.

        Parameters
        ----------
        args : (not used)
        """
        # NOTE(review): ignore_degenerate is hard-coded to True, so the
        # ignoreDegenerate constructor argument has no effect here. Left
        # unchanged because switching to self.ignoreDegenerate (default
        # False) would alter results for existing callers -- confirm intent.
        self.regridObj = ESMF.Regrid(srcfield=self.srcFld.field,
                                     dstfield=self.dstFld.field,
                                     src_mask_values=self.srcMaskValues,
                                     dst_mask_values=self.dstMaskValues,
                                     regrid_method=self.regridMethod,
                                     unmapped_action=self.unMappedAction,
                                     ignore_degenerate=True)

    def apply(self, srcData, dstData, rootPe, globalIndexing=False, **args):
        """
        Regrid source to destination.

        The result is written into dstData in place; nothing is returned.
        computeWeights must have been called beforehand.

        NOTE: support for a 'srcDataMask' entry in args (re-masking the
        source grid and recomputing the weights) is currently disabled;
        such an entry is ignored.

        Parameters
        ----------
        srcData : array source data, shape should cover entire global index
                  space
        dstData : array destination data, shape should cover entire global
                  index space; its incoming values are copied into the
                  destination field before regridding
        rootPe : if None and globalIndexing is True, only the slab of
                 dstData local to this processor is filled in; otherwise
                 all of dstData is overwritten with this processor's
                 destination field data
        globalIndexing : if True array was allocated over global index
                         space, otherwise array was allocated over local
                         index space on this processor. This is only
                         relevant if rootPe is None
        args : optional 'zero_region' entry, forwarded to the regridding
               operator (default ESMF.Region.SELECT)
        """
        zero_region = args.get('zero_region', ESMF.Region.SELECT)

        # the ESMF fields hold the data transposed with respect to the
        # arrays handed in by the caller
        self.srcFld.field.data[:] = srcData.T
        self.dstFld.field.data[:] = dstData.T

        # regrid
        self.regridObj(
            self.srcFld.field,
            self.dstFld.field,
            zero_region=zero_region)

        # fill in dstData
        if rootPe is None and globalIndexing:
            # only fill in the relevant portion of the data
            slab = self.dstGrid.getLocalSlab(staggerloc=self.staggerloc)
            dstData[slab] = self.dstFld.getData(rootPe=rootPe)
        else:
            regridded = self.dstFld.field.data
            # test for None before transposing; there is nothing to copy
            # back when the field holds no data
            if regridded is not None:
                dstData[:] = regridded.T

    def getDstGrid(self):
        """
        Get the destination grid on this processor

        Returns
        -------
        grid : list with one coordinate array per dimension
        """
        return [self.dstGrid.getCoords(i, staggerloc=self.staggerloc)
                for i in range(self.ndims)]

    def getSrcAreas(self, rootPe):
        """
        Get the source grid cell areas

        Parameters
        ----------
        rootPe : root processor where data should be gathered (or None if
                 local areas are to be returned); currently not used

        Returns
        -------
        areas or None if non-conservative interpolation
        """
        # NOTE(review): the call that would compute the areas
        # (self.srcAreaField.field.get_area()) is disabled, so the field is
        # returned with whatever it currently holds (it is initialised to
        # zero in the constructor) -- confirm this is intended.
        if self.regridMethod == CONSERVE:
            return self.srcAreaField.field.data
        return None

    def getDstAreas(self, rootPe):
        """
        Get the destination grid cell areas

        Parameters
        ----------
        rootPe : root processor where data should be gathered (or None if
                 local areas are to be returned); currently not used

        Returns
        -------
        areas or None if non-conservative interpolation
        """
        # NOTE(review): the call that would compute the areas
        # (self.dstAreaField.field.get_area()) is disabled, so the field is
        # returned with whatever it currently holds (it is initialised to
        # zero in the constructor) -- confirm this is intended.
        if self.regridMethod == CONSERVE:
            return self.dstAreaField.field.data
        return None

    def getSrcAreaFractions(self, rootPe):
        """
        Get the source grid area fractions

        Parameters
        ----------
        rootPe : root processor where data should be gathered (or None if
                 local areas are to be returned); currently not used

        Returns
        -------
        fractional areas or None (if non-conservative)
        """
        if self.regridMethod == CONSERVE:
            return self.srcFracField.field.data
        return None

    def getDstAreaFractions(self, rootPe):
        """
        Get the destination grid area fractions

        Parameters
        ----------
        rootPe : root processor where data should be gathered (or None if
                 local areas are to be returned); currently not used

        Returns
        -------
        fractional areas or None (if non-conservative)
        """
        if self.regridMethod == CONSERVE:
            return self.dstFracField.field.data
        return None

    @staticmethod
    def _parseStaggerLoc(staggerLoc):
        """
        Translate a stagger name into the matching ESMF stagger constant.

        Matching is case insensitive: a name starting with 'corner' or
        containing 'nod' gives CORNER, a name containing 'vface' or
        'vcorner' gives VFACE, anything else gives CENTER.

        Parameters
        ----------
        staggerLoc : (e.g. 'center' or 'corner')

        Returns
        -------
        CENTER, CORNER or VFACE
        """
        if re.match('corner', staggerLoc, re.I) or \
                re.search('nod', staggerLoc, re.I):
            return CORNER
        if re.search('vface', staggerLoc, re.I) or \
                re.search('vcorner', staggerLoc, re.I):
            return VFACE
        return CENTER

    def getSrcLocalShape(self, staggerLoc):
        """
        Get the local source coordinate/data shape (may be different on
        each processor)

        Parameters
        ----------
        staggerLoc : (e.g. 'center' or 'corner')

        Returns
        -------
        tuple
        """
        return self.srcGrid.getCoordShape(self._parseStaggerLoc(staggerLoc))

    def getDstLocalShape(self, staggerLoc):
        """
        Get the local destination coordinate/data shape (may be different
        on each processor)

        Parameters
        ----------
        staggerLoc : (e.g. 'center' or 'corner')

        Returns
        -------
        tuple
        """
        return self.dstGrid.getCoordShape(self._parseStaggerLoc(staggerLoc))

    def getSrcLocalSlab(self, staggerLoc):
        """
        Get the source local slab (ellipsis). You can use this to grab the
        data local to this processor

        Parameters
        ----------
        staggerLoc : (e.g. 'center')

        Returns
        -------
        tuple of slices
        """
        return self.srcGrid.getLocalSlab(self._parseStaggerLoc(staggerLoc))

    def getDstLocalSlab(self, staggerLoc):
        """
        Get the destination local slab (ellipsis). You can use this to grab
        the data local to this processor

        Parameters
        ----------
        staggerLoc : (e.g. 'center')

        Returns
        -------
        tuple of slices
        """
        return self.dstGrid.getLocalSlab(self._parseStaggerLoc(staggerLoc))

    def fillInDiagnosticData(self, diag, rootPe):
        """
        Fill in diagnostic data

        Parameters
        ----------
        diag : a dictionary whose entries, if present, will be filled.
               Valid entries are: 'srcAreaFractions', 'dstAreaFractions',
               'srcAreas', 'dstAreas'. These are set to the transposed
               arrays, or to None for non-conservative regridding. The
               keys 'regridTool', 'regridMethod', 'periodicity',
               'coordSys' and 'staggerLoc' are always set.
        rootPe : root processor where data should be gathered (or None if
                 local areas are to be returned)
        """
        # bound methods rather than eval'ed strings
        getters = (
            ('srcAreaFractions', self.getSrcAreaFractions),
            ('dstAreaFractions', self.getDstAreaFractions),
            ('srcAreas', self.getSrcAreas),
            ('dstAreas', self.getDstAreas),
        )
        for entry, getter in getters:
            if entry in diag:
                values = getter(rootPe=rootPe)
                # the getters return None unless the method is
                # conservative; only transpose real arrays
                diag[entry] = None if values is None else values.T
        diag['regridTool'] = 'esmf'
        diag['regridMethod'] = self.regridMethodStr
        diag['periodicity'] = self.periodicity
        diag['coordSys'] = self.coordSysStr
        diag['staggerLoc'] = self.staggerlocStr