def align_func(output_shape, output_dtype):
    """Decorator that ensures the output of ``func`` is an
    :class:`~numpy:numpy.ndarray` with the given shape and dtype.

    Parameters
    ----------
    output_shape : tuple of int
        Desired shape for function output (must have the same size as actual
        function output)
    output_dtype : ``tf.DType`` or :class:`~numpy:numpy.dtype`
        Desired dtype of function output

    Returns
    -------
    callable
        A decorator. Applied to ``func``, it returns a wrapper taking the same
        positional arguments, whose output is validated, cast to
        ``output_dtype`` and reshaped to ``output_shape``.

    Raises
    ------
    :class:`~nengo:nengo.exceptions.SimulationError`
        If the function returns ``None``, a non-finite value, or a value whose
        finiteness cannot be checked by numpy (invalid type).
    """
    import functools

    # TensorFlow dtypes are normalised to the equivalent numpy dtype for the
    # ``np.asarray`` cast performed in the wrapper
    if isinstance(output_dtype, tf.DType):
        output_dtype = output_dtype.as_numpy_dtype

    def apply_align(func):
        # ``wraps`` copies ``func``'s metadata (``__name__``, ``__doc__``,
        # ``__wrapped__``) onto the wrapper, so introspection reflects the
        # user's function rather than ``aligned_func``
        @functools.wraps(func)
        def aligned_func(*args):
            output = func(*args)

            # checked explicitly to give a clear message instead of the
            # generic error ``np.isfinite(None)`` would produce
            if output is None:
                raise SimulationError(
                    "Function %r returned None" %
                    function_name(func, sanitize=False))

            # only the finiteness test belongs in the ``try``; numpy raises
            # TypeError/ValueError for outputs it cannot interpret
            # numerically (e.g. strings)
            try:
                finite = np.all(np.isfinite(output))
            except (TypeError, ValueError) as err:
                raise SimulationError(
                    "Function %r returned a value %r of invalid type %r" %
                    (function_name(func, sanitize=False), output,
                     type(output))) from err
            if not finite:
                raise SimulationError(
                    "Function %r returned invalid value %r" %
                    (function_name(func, sanitize=False), output))

            output = np.asarray(output, dtype=output_dtype)
            # numpy raises ValueError here if the number of elements in
            # ``output`` does not match ``output_shape``
            return output.reshape(output_shape)

        return aligned_func

    return apply_align
评论列表
文章目录