def test_snapshot_to_xarray_variable(self, ds_model_interface):
    # Snapshots of a model variable are converted to xarray.Variable objects:
    # one snapshot -> plain variable with attrs, several -> stacked on clock.
    iface = ds_model_interface
    iface.init_snapshots()
    iface.set_model_inputs(iface.dataset)
    iface.model.initialize()
    quantity_key = ('quantity', 'quantity')

    iface.take_snapshots(0)
    single_expected = xr.Variable('x', np.zeros(10),
                                  {'description': 'a quantity'})
    single_actual = iface.snapshot_to_xarray_variable(quantity_key,
                                                      clock='clock')
    xr.testing.assert_identical(single_actual, single_expected)

    iface.take_snapshots(-1)
    stacked_expected = xr.Variable(('clock', 'x'), np.zeros((2, 10)))
    stacked_actual = iface.snapshot_to_xarray_variable(quantity_key,
                                                       clock='clock')
    xr.testing.assert_equal(stacked_actual, stacked_expected)

    grid_expected = xr.Variable('x', np.arange(10))
    grid_actual = iface.snapshot_to_xarray_variable(('grid', 'x'))
    xr.testing.assert_equal(grid_actual, grid_expected)
# Example source code using the Python class Variable()
def snapshot_to_xarray_variable(self, key, clock=None):
    """Build an ``xarray.Variable`` from the snapshots recorded for one
    model variable.

    ``key`` is a ``(process_name, variable_name)`` pair used both to look up
    the model variable and the recorded snapshot values. A single snapshot
    is used as-is; several snapshots are stacked along a new leading axis,
    whose dimension name ``clock`` is prepended when it is given. The
    variable's attrs are copied and its description added to them.
    """
    proc_name, var_name = key
    model_var = self.model._processes[proc_name]._variables[var_name]
    snapshots = self.snapshot_values[key]
    first = snapshots[0]
    is_stacked = len(snapshots) > 1

    data = np.stack(snapshots) if is_stacked else first
    dims = _get_dims_from_variable(first, model_var)
    if is_stacked and clock is not None:
        dims = (clock,) + dims

    attrs = model_var.attrs.copy()
    attrs['description'] = model_var.description
    return xr.Variable(dims, data, attrs=attrs)
def _load_GeoTransform(self):
    """Return ``(lat, lon)`` 1-D variables derived from the geotransform
    returned by ``self.ds.GetGeoTransform()``.

    Latitude uses elements 3 (origin) and 5 (step) of the transform,
    longitude uses elements 0 (origin) and 1 (step). When ``with_dask`` is
    set the values are wrapped in lazy single-chunk arrays."""
    ds = self.ds
    geo = ds.GetGeoTransform()

    def load_lat():
        return arange(ds.RasterYSize) * geo[5] + geo[3]

    def load_lon():
        return arange(ds.RasterXSize) * geo[1] + geo[0]

    if not with_dask:
        lat, lon = load_lat(), load_lon()
    else:
        nlat = (ds.RasterYSize,)
        nlon = (ds.RasterXSize,)
        lat = Array({('lat', 0): (load_lat,)}, 'lat', nlat,
                    shape=nlat, dtype=float)
        lon = Array({('lon', 0): (load_lon,)}, 'lon', nlon,
                    shape=nlon, dtype=float)
    return Variable(('lat',), lat), Variable(('lon',), lon)
def can_decode(cls, ds, var):
    """
    Tell whether this decoder class is able to decode `var`.

    Parameters
    ----------
    ds: xarray.Dataset
        Dataset in which `var` lives
    var: xarray.Variable or xarray.DataArray
        The array to check

    Returns
    -------
    bool
        Always True for this base implementation. Subclasses should override
        this method to accept only the kind of data they understand.
    """
    return True
def standardize_dims(self, var, dims=None):
    """Replace the coordinate names through x, y, z and t

    Parameters
    ----------
    var: xarray.Variable
        The variable to use the dimensions of
    dims: dict
        The dictionary to use for replacing the original dimensions. It is
        not modified; a new dictionary is returned. None (the default) is
        treated as an empty mapping.

    Returns
    -------
    dict
        The dictionary with replaced dimensions"""
    # work on a copy so the caller's mapping is never modified
    dims = {} if dims is None else dict(dims)
    coords = self.ds.coords
    name_map = {self.get_xname(var, coords): 'x',
                self.get_yname(var, coords): 'y',
                self.get_zname(var, coords): 'z',
                self.get_tname(var, coords): 't'}
    for dim in set(dims).intersection(name_map):
        dims[name_map[dim]] = dims.pop(dim)
    return dims
def get_mesh(self, var, coords=None):
    """Look up the mesh variable named by the ``'mesh'`` attribute of `var`

    Parameters
    ----------
    var: xarray.Variable
        The data source with the ``'mesh'`` attribute
    coords: dict
        The coordinates to search first. If None, the coordinates of the
        dataset of this decoder are used

    Returns
    -------
    xarray.Coordinate
        The mesh coordinate, or None if `var` has no ``'mesh'`` attribute.
        If the name is missing from `coords`, the dataset coordinates are
        used as a fallback (which may itself give None)."""
    mesh_name = var.attrs.get('mesh')
    if mesh_name is None:
        return None
    fallback = self.ds.coords.get(mesh_name)
    lookup = self.ds.coords if coords is None else coords
    return lookup.get(mesh_name, fallback)
def test_plot_bounds_2d(self):
    # 2D plot bounds are cell edges, except that the outer borders stay at
    # the outermost cell centres
    x2d, y2d = np.meshgrid(np.arange(1, 5), np.arange(5, 10))

    x_edges = np.arange(0.5, 4.51, 1.0)
    y_edges = np.arange(4.5, 9.51, 1.0)
    x_edges[0], x_edges[-1] = 1.0, 4.0
    y_edges[0], y_edges[-1] = 5.0, 9.0
    x2d_edges, y2d_edges = np.meshgrid(x_edges, y_edges)

    decoder = psyd.CFDecoder()
    for centres, edges in ((x2d, x2d_edges), (y2d, y2d_edges)):
        bounds = decoder.get_plotbounds(xr.Variable(('y', 'x'), centres))
        self.assertAlmostArrayEqual(bounds, edges)
def _from_dataset_test_variables(self):
    """Return the ``(variables, coords)`` pair used by the from_dataset
    tests: four zero-filled variables of rank 3, 2, 2 and 1 on a 4x4x4
    time/ydim/xdim grid."""
    size = 4
    variables = {
        'v0': xr.Variable(('time', 'ydim', 'xdim'),
                          np.zeros((size, size, size))),   # 3d
        'v1': xr.Variable(('time', 'xdim'), np.zeros((size, size))),  # t, x
        'v2': xr.Variable(('ydim', 'xdim'), np.zeros((size, size))),  # y, x
        'v3': xr.Variable(('xdim',), np.zeros(size)),                 # 1d
    }
    month_ends = pd.date_range('1999-01-01', '1999-05-01', freq='M').values
    coords = {
        'ydim': xr.Variable(('ydim',), np.arange(1, size + 1)),
        'xdim': xr.Variable(('xdim',), np.arange(size)),
        'time': xr.Variable(('time',), month_ends),
    }
    return variables, coords
def read_netcdf(data_handle, domain=None, iter_dims=('lat', 'lon'),
                start=None, stop=None, calendar='standard',
                var_dict=None) -> xr.Dataset:
    """Read in a NetCDF file

    Parameters
    ----------
    data_handle:
        Anything accepted by ``xr.open_dataset`` (e.g. a file path).
    domain: mapping, optional
        If given, ``domain[d]`` is passed to ``.sel`` for every ``d`` in
        `iter_dims`.
    iter_dims: sequence of str
        Dimension names used to subset by `domain`.
    start, stop: optional
        If both are given, the data are restricted to
        ``time=slice(start, stop)`` and a ``day_of_year`` variable is added.
    calendar: str
        Not used by this function.
    var_dict: mapping, optional
        Passed to ``Dataset.rename`` to rename variables.

    Returns
    -------
    xr.Dataset
        The selected data, loaded into memory.
    """
    # the context manager closes the opened file even if a selection fails;
    # closing a derived dataset (as before) is not guaranteed to do that
    with xr.open_dataset(data_handle) as opened:
        ds = opened
        if var_dict is not None:
            # ``rename(..., inplace=True)`` was removed from xarray: rebind
            ds = ds.rename(var_dict)
        if start is not None and stop is not None:
            ds = ds.sel(time=slice(start, stop))
            dates = ds.indexes['time']
            ds['day_of_year'] = xr.Variable(('time', ), dates.dayofyear)
        if domain is not None:
            ds = ds.sel(**{d: domain[d] for d in iter_dims})
        # load before the file is closed on leaving the ``with`` block
        out = ds.load()
    return out
def read_data(data_handle, domain=None, iter_dims=('lat', 'lon'),
              start=None, stop=None, calendar='standard',
              var_dict=None) -> xr.Dataset:
    """Read data directly from an xarray dataset

    Parameters
    ----------
    data_handle: xr.Dataset
        The source dataset. It is no longer renamed in place; note that it
        is still closed at the end when no selection creates a new dataset.
    domain: mapping, optional
        If given, ``domain[d]`` is passed to ``.sel`` for every ``d`` in
        `iter_dims`.
    iter_dims: sequence of str
        Dimension names used to subset by `domain`.
    start, stop: optional
        If both are given, only the variables of interest are kept, they are
        restricted to ``time=slice(start, stop)`` and a ``day_of_year``
        variable is added.
    calendar: str
        Not used by this function.
    var_dict: mapping, optional
        Passed to ``Dataset.rename``; its values become the variables of
        interest.

    Returns
    -------
    xr.Dataset
        The selected data, loaded into memory.
    """
    varlist = list(data_handle.keys())
    if var_dict is not None:
        # rebind instead of ``rename(..., inplace=True)``: the flag was
        # removed from xarray and it silently mutated the caller's dataset
        data_handle = data_handle.rename(var_dict)
        varlist = list(var_dict.values())
    if start is not None and stop is not None:
        data_handle = data_handle[varlist].sel(time=slice(start, stop))
        dates = data_handle.indexes['time']
        data_handle['day_of_year'] = xr.Variable(('time', ), dates.dayofyear)
    if domain is not None:
        data_handle = data_handle.sel(**{d: domain[d] for d in iter_dims})
    out = data_handle.load()
    data_handle.close()
    return out
def test_constructor(self):
    # verify allowed_dims
    # empty specifications all mean "scalar only"
    for allowed_dims in (tuple(), list(), ''):
        var = Variable(allowed_dims)
        assert var.allowed_dims == ((),)
    # a bare string, a one-element list and a one-element tuple all describe
    # the single dimension 'x' (``('x')`` would just be the string again)
    for allowed_dims in ('x', ['x'], ('x',)):
        var = Variable(allowed_dims)
        assert var.allowed_dims == (('x',),)
    var = Variable(('x', 'y'))
    assert var.allowed_dims == (('x', 'y'),)
    var = Variable([(), 'x', ('x', 'y')])
    assert var.allowed_dims == ((), ('x',), ('x', 'y'))
def test_validators(self):
    # class-level default validators come first, then user supplied ones
    def validator_func(xr_var):
        return xr_var is not None

    class MyVariable(Variable):
        default_validators = [validator_func]

    var = MyVariable((), validators=[validator_func])
    assert var.validators == [validator_func] * 2
def test_validate_dimensions(self):
    var = Variable([(), 'x', ('x', 'y')])
    # ('x', 'z') is not one of the allowed dimension tuples
    with pytest.raises(ValidationError, match='invalid dimensions'):
        var.validate_dimensions(('x', 'z'))
    # dimensions listed in ignore_dims are skipped during validation
    var.validate_dimensions(('time', 'x'), ignore_dims=['time'])
def test_to_xarray_variable(self):
    attrs = {'units': 'm'}
    description = 'x var'
    expected_attrs = dict(attrs, description=description)

    # (dims, data) tuples keep the attrs and gain the description
    var = Variable('x', description=description, attrs=attrs)
    xr.testing.assert_identical(
        var.to_xarray_variable(('x', [1, 2])),
        xr.Variable('x', data=[1, 2], attrs=expected_attrs))

    var = Variable((), default_value=1)
    checks = (
        (2, xr.Variable((), data=2)),
        # None falls back to the default value
        (None, xr.Variable((), data=1)),
        # a bare sequence is expected to get the 'this_variable' dimension
        # and to become an index variable
        ([1, 2],
         xr.Variable('this_variable', data=[1, 2]).to_index_variable()),
    )
    for value, expected in checks:
        xr.testing.assert_identical(var.to_xarray_variable(value), expected)
def test_constructor(self):
    # a VariableList is a tuple of Variable objects only
    var_list = VariableList([Variable(()), Variable('x')])
    assert isinstance(var_list, tuple)
    with pytest.raises(ValueError, match="found variables mixed"):
        VariableList([2, Variable(())])
def test_validators(self):
    var = FloatVariable(())
    # ints and floats are both accepted
    for accepted in (1, 1.):
        var.run_validators(xr.Variable((), accepted))
    # a string is rejected
    with pytest.raises(ValidationError, match="invalid dtype"):
        var.run_validators(xr.Variable((), '1'))
def test_validators(self):
    var = IntegerVariable(())
    var.run_validators(xr.Variable((), 1))
    # floats and strings are both rejected
    for rejected in (1., '1'):
        with pytest.raises(ValidationError, match="invalid dtype"):
            var.run_validators(xr.Variable((), rejected))
def _get_dims_from_variable(array, variable):
"""Given an array of values (snapshot) and a (xarray-simlab) Variable
object, Return dimension labels for the array."""
for dims in variable.allowed_dims:
if len(dims) == array.ndim:
return dims
return tuple()
def get_variables(self):
    """Return one ``Variable`` per raster band of :attr:`ds` (named
    ``Band1``, ``Band2``, ...) plus ``lat``/``lon`` as a FrozenOrderedDict.

    Each band has dimensions ``('lat', 'lon')``. Its ``_FillValue`` attr is
    NaN when the band dtype can represent NaN, otherwise the band's NoData
    value (if one is set)."""
    from uuid import uuid4

    def load(band):
        # read one band; replace NoData cells by NaN where the dtype allows
        # it (integer dtypes raise ValueError and are left unmasked)
        band = ds.GetRasterBand(band)
        a = band.ReadAsArray()
        no_data = band.GetNoDataValue()
        if no_data is not None:
            try:
                a[a == no_data] = a.dtype.type(nan)
            except ValueError:
                pass
        return a
    ds = self.ds
    dims = ['lat', 'lon']
    chunks = ((ds.RasterYSize,), (ds.RasterXSize,))
    shape = (ds.RasterYSize, ds.RasterXSize)
    # dask identifies arrays by name: every band (and every opened file)
    # needs its own name, otherwise merged graphs share the same key and
    # all bands would compute to the same data
    token = uuid4().hex
    variables = OrderedDict()
    for iband in range(1, ds.RasterCount+1):
        band = ds.GetRasterBand(iband)
        dt = dtype(gdal_array.codes[band.DataType])
        if with_dask:
            name = 'band%i-%s' % (iband, token)
            dsk = {(name, 0, 0): (load, iband)}
            arr = Array(dsk, name, chunks, shape=shape, dtype=dt)
        else:
            arr = load(iband)
        attrs = band.GetMetadata_Dict()
        try:
            dt.type(nan)
            attrs['_FillValue'] = nan
        except ValueError:
            no_data = band.GetNoDataValue()
            # compare with None: a NoData value of 0 is valid and common
            if no_data is not None:
                attrs['_FillValue'] = no_data
        variables['Band%i' % iband] = Variable(dims, arr, attrs)
    variables['lat'], variables['lon'] = self._load_GeoTransform()
    return FrozenOrderedDict(variables)
def get_index_from_coord(coord, base_index):
    """Function to return the coordinate as integer, integer array or slice

    If `coord` is zero-dimensional, the corresponding integer in `base_index`
    will be supplied. Otherwise it is first tried to return a slice, if that
    does not work an integer array with the corresponding indices is returned.

    Parameters
    ----------
    coord: xarray.Coordinate, xarray.Variable or array-like
        Coordinate to convert. Objects without a ``values`` attribute
        (plain scalars, lists, numpy arrays) are converted with
        ``np.asarray``.
    base_index: pandas.Index
        The base index from which the `coord` was extracted

    Returns
    -------
    int, array of ints or slice
        The indexer that can be used to access the `coord` in the
        `base_index`
    """
    try:
        values = coord.values
    except AttributeError:
        # plain scalars/sequences have no ``ndim``; coerce them first
        values = np.asarray(coord)
    if values.ndim == 0:
        return base_index.get_loc(values[()])
    # the coordinate covers the whole index in order -> take everything
    if len(values) == len(base_index) and (values == base_index).all():
        return slice(None)
    values = np.array([base_index.get_loc(val) for val in values])
    # prefer a slice when the positions allow one, else the position array
    return to_slice(values) or values
#: mapping that translates datetime format strings to regex patterns
def to_netcdf(ds, *args, **kwargs):
    """
    Store the given dataset as a netCDF file

    This function works essentially the same as the usual
    :meth:`xarray.Dataset.to_netcdf` method but can also encode absolute time
    units. Datetime variables carrying the absolute time unit string checked
    in the code are wrapped with :class:`AbsoluteTimeEncoder` before writing;
    the given dataset itself is not modified.

    Parameters
    ----------
    ds: xarray.Dataset
        The dataset to store
    %(xarray.Dataset.to_netcdf.parameters)s
    """
    to_update = {}
    for v, obj in six.iteritems(ds.variables):
        # the units may live in attrs or, after decoding, in the encoding
        units = obj.attrs.get('units', obj.encoding.get('units', None))
        if units == 'day as %Y%m%d.%f' and np.issubdtype(
                obj.dtype, np.datetime64):
            to_update[v] = xr.Variable(
                obj.dims, AbsoluteTimeEncoder(obj), attrs=obj.attrs.copy(),
                encoding=obj.encoding)
            to_update[v].attrs['units'] = units
    if to_update:
        # ``Dataset.update(..., inplace=False)`` is no longer supported by
        # xarray; update a shallow copy so the caller's dataset is untouched
        ds = ds.copy()
        ds.update(to_update)
    return xarray_api.to_netcdf(ds, *args, **kwargs)
def get_x(self, var, coords=None):
    """
    Get the x-coordinate of a variable

    This method searches for the x-coordinate in the :attr:`ds`. It first
    checks whether there is one dimension that holds an ``'axis'``
    attribute with 'X', otherwise it looks whether there is an intersection
    between the :attr:`x` attribute and the variables dimensions, otherwise
    it returns the coordinate corresponding to the last dimension of `var`

    Parameters
    ----------
    var: xarray.Variable
        The variable to get the x-coordinate for
    coords: dict
        Coordinates to use. If None (or empty), the coordinates of the
        dataset in the :attr:`ds` attribute are used.

    Returns
    -------
    xarray.Coordinate or None
        The x-coordinate or None if it could not be found"""
    coords = coords or self.ds.coords
    coord = self.get_variable_by_axis(var, 'x', coords)
    if coord is not None:
        return coord
    # fall back to the name reported by get_xname
    return coords.get(self.get_xname(var))
def get_y(self, var, coords=None):
    """
    Get the y-coordinate of a variable

    This method searches for the y-coordinate in the :attr:`ds`. It first
    checks whether there is one dimension that holds an ``'axis'``
    attribute with 'Y', otherwise it looks whether there is an intersection
    between the :attr:`y` attribute and the variables dimensions, otherwise
    it returns the coordinate corresponding to the second last dimension of
    `var` (or the last if the dimension of var is one-dimensional)

    Parameters
    ----------
    var: xarray.Variable
        The variable to get the y-coordinate for
    coords: dict
        Coordinates to use. If None (or empty), the coordinates of the
        dataset in the :attr:`ds` attribute are used.

    Returns
    -------
    xarray.Coordinate or None
        The y-coordinate or None if it could not be found"""
    coords = coords or self.ds.coords
    coord = self.get_variable_by_axis(var, 'y', coords)
    if coord is not None:
        return coord
    # fall back to the name reported by get_yname
    return coords.get(self.get_yname(var))
def get_z(self, var, coords=None):
    """
    Get the vertical (z-) coordinate of a variable

    The coordinate is first looked up via ``get_variable_by_axis`` with
    axis ``'z'``. If that gives nothing, the name returned by ``get_zname``
    is looked up in the coordinates (the original description: an
    ``'axis'`` attribute of 'Z', then the :attr:`z` attribute, then the
    third last dimension of `var`, or the second last / last for two- or
    one-dimensional variables).

    Parameters
    ----------
    var: xarray.Variable
        The variable to get the z-coordinate for
    coords: dict
        Coordinates to use. If None (or empty), the coordinates of the
        dataset in the :attr:`ds` attribute are used.

    Returns
    -------
    xarray.Coordinate or None
        The z-coordinate or None if no z coordinate could be found"""
    coords = coords or self.ds.coords
    found = self.get_variable_by_axis(var, 'z', coords)
    if found is None:
        zname = self.get_zname(var)
        found = None if zname is None else coords.get(zname)
    return found
def get_t(self, var, coords=None):
    """
    Get the time coordinate of a variable

    Lookup order: ``get_variable_by_axis`` with axis ``'t'``; then the
    dimensions of `var` that are both in :attr:`t` and in the coordinates
    (a warning is issued if there is more than one); finally the name
    returned by ``get_tname``.

    Parameters
    ----------
    var: xarray.Variable
        The variable to get the time coordinate for
    coords: dict
        Coordinates to use. If None (or empty), the coordinates of the
        dataset in the :attr:`ds` attribute are used.

    Returns
    -------
    xarray.Coordinate or None
        The time coordinate or None if no time coordinate could be found"""
    coords = coords or self.ds.coords
    by_axis = self.get_variable_by_axis(var, 't', coords)
    if by_axis is not None:
        return by_axis

    candidates = list(self.t.intersection(var.dims).intersection(coords))
    if len(candidates) > 1:
        warn("Found multiple matches for time coordinate in the "
             "variable: %s. I use %s" % (
                 ', '.join(candidates), candidates[0]),
             PsyPlotRuntimeWarning)
    if candidates:
        return coords[candidates[0]]

    tname = self.get_tname(var)
    return None if tname is None else coords.get(tname)
def correct_dims(self, var, dims=None, remove=True):
    """Expands the dimensions to match the dims in the variable

    Generic keys ``'x'``, ``'y'``, ``'z'`` and ``'t'`` in `dims` that are not
    dimensions of `var` are translated to the names returned by the
    decoder's ``get_xname``/``get_yname``/``get_zname``/``get_tname``.

    Parameters
    ----------
    var: xarray.Variable
        The variable to get the data for
    dims: dict
        a mapping from dimension to the slices. It is not modified; None
        (the default) means an empty mapping.
    remove: bool
        If True, dimensions in `dims` that are not in the dimensions of
        `var` are removed

    Returns
    -------
    dict
        The corrected copy of `dims`"""
    method_mapping = {'x': self.get_xname,
                      'z': self.get_zname, 't': self.get_tname}
    dims = {} if dims is None else dict(dims)
    if self.is_unstructured(var):  # we assume a one-dimensional grid
        method_mapping['y'] = self.get_xname
    else:
        method_mapping['y'] = self.get_yname
    # iterate over a snapshot of the keys because dims is modified below
    for key in list(dims):
        if key in method_mapping and key not in var.dims:
            dim_name = method_mapping[key](var, self.ds.coords)
            if dim_name in dims:
                # the real dimension is already given explicitly -> drop
                # the generic alias
                dims.pop(key)
            else:
                new_name = method_mapping[key](var)
                if new_name is not None:
                    dims[new_name] = dims.pop(key)
    # now remove the unnecessary dimensions
    if remove:
        for key in set(dims).difference(var.dims):
            dims.pop(key)
            self.logger.debug(
                "Could not find a dimensions matching %s in variable %s!",
                key, var)
    return dims
def test_1D_cf_bounds(self):
    """Test whether CF-convention bounds of 1D coordinates are used"""
    edges = np.arange(-180, 181, 30)
    lon = xr.Variable(('lon', ), np.arange(-165, 166, 30),
                      {'bounds': 'lon_bounds'})
    # CF-style (n, 2) bounds: row i holds the lower and upper edge of cell i
    bounds_data = np.zeros((len(lon), 2))
    bounds_data[:, 0] = edges[:-1]
    bounds_data[:, 1] = edges[1:]
    cf_bounds = xr.Variable(('lon', 'bnds'), bounds_data)
    decoder = psyd.CFDecoder(
        xr.Dataset(coords={'lon': lon, 'lon_bounds': cf_bounds}))
    self.assertEqual(list(edges), list(decoder.get_plotbounds(lon)))
def test_1D_bounds_calculation(self):
    """Test that 1D cell boundaries are computed from the cell centres"""
    lon = xr.Variable(('lon', ), np.arange(-165, 166, 30))
    decoder = psyd.CFDecoder(xr.Dataset(coords={'lon': lon}))
    expected = list(np.arange(-180, 181, 30))
    self.assertEqual(expected, list(decoder.get_plotbounds(lon)))
def _filter_test_ds(self):
    """Dataset whose variables v0, v1, v2 carry ``test``/``test2`` attrs of
    1, 2 and 3 respectively (v0 is 2D, v1 and v2 are 1D)."""
    def one_d(value):
        return xr.Variable(('xdim', ), np.zeros(4),
                           attrs={'test': value, 'test2': value})

    data_vars = {
        'v0': xr.Variable(('ydim', 'xdim'), np.zeros((4, 4)),
                          attrs={'test': 1, 'test2': 1}),
        'v1': one_d(2),
        'v2': one_d(3),
    }
    coords = {
        'ydim': xr.Variable(('ydim', ), np.arange(1, 5)),
        'xdim': xr.Variable(('xdim', ), np.arange(4)),
    }
    return xr.Dataset(data_vars, coords)
def _test_ds(self):
    """Dataset with a ``test`` variable on a three-step ``time`` index whose
    encoding requests the absolute ``day as ...`` units."""
    import xarray as xr
    import pandas as pd
    # ``xr.Coordinate`` was a deprecated alias that has been removed from
    # xarray; ``IndexVariable`` is its replacement
    time = xr.IndexVariable('time', pd.to_datetime(
        ['1979-01-01T12:00:00', '1979-01-01T18:00:00',
         '1979-01-01T18:30:00']),
        encoding={'units': 'day as %Y%m%d.%f'})
    var = xr.Variable(('time', 'x'), np.zeros((len(time), 5)))
    return xr.Dataset({'test': var}, {'time': time})