def _valid_input(self, value, dtype=None):
    """Validate a candidate parameter value and normalize it to an array.

    Scalars (``misc.is_number``) and lists (``misc.is_list``) are converted
    with ``np.array``. Any other accepted value is cast with
    ``value.astype(dtype)`` only when ``dtype`` was requested and the
    parameter does not already hold a value.

    When ``self._value`` exists, the parameter's own ``self.dtype`` is
    used instead of a requested ``dtype``. If the parameter is also built
    (``Build.YES``) with ``self.fixed_shape`` set, its ``self.shape`` must
    match the converted value's shape.

    Args:
        value: Candidate value; must pass ``misc.is_valid_param_value``.
        dtype: Optional requested dtype. If ``None``, scalars use
            ``misc.normalize_num_type`` of their numpy result type and lists
            use ``settings.float_type``.

    Returns:
        The validated value, converted to ``np.ndarray`` for scalars and
        lists, or cast via ``astype`` when requested.

    Raises:
        ValueError: If ``value`` is not a valid parameter value; if
            ``dtype`` differs from the existing parameter's dtype; if an
            ``np.ndarray`` value's dtype differs from the existing
            parameter's dtype; or if a built, fixed-shape parameter receives
            a value of a different shape.
    """
    if not misc.is_valid_param_value(value):
        msg = 'The value must be either a tensorflow variable, an array or a scalar.'
        raise ValueError(msg)

    cast = dtype is not None
    is_built = False
    shape = None

    if hasattr(self, '_value'):  # ``_value`` exists: the parameter has already been initialized.
        is_built = self.is_built_coherence() == Build.YES
        shape = self.shape
        inner_dtype = self.dtype
        if dtype is not None and inner_dtype != dtype:
            msg = 'Overriding parameter\'s type "{0}" with "{1}" is not possible.'
            raise ValueError(msg.format(inner_dtype, dtype))
        elif isinstance(value, np.ndarray) and inner_dtype != value.dtype:
            msg = 'The value has different data type "{0}". Parameter type is "{1}".'
            raise ValueError(msg.format(value.dtype, inner_dtype))
        # The existing parameter dictates the dtype: scalars/lists below are
        # converted to it, and ndarray inputs were just checked to match, so
        # no explicit astype() cast is needed.
        cast = False
        dtype = inner_dtype

    if misc.is_number(value):
        value_type = np.result_type(value).type
        num_type = misc.normalize_num_type(value_type)
        dtype = num_type if dtype is None else dtype
        value = np.array(value, dtype=dtype)
    elif misc.is_list(value):
        dtype = settings.float_type if dtype is None else dtype
        value = np.array(value, dtype=dtype)
    elif cast:
        value = value.astype(dtype)

    # Shape is only enforced once the parameter is built and declared fixed-shape.
    if shape is not None and self.fixed_shape and is_built and shape != value.shape:
        msg = 'Value has different shape. Parameter shape {0}, value shape {1}.'
        raise ValueError(msg.format(shape, value.shape))

    return value
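# 评论列表 ("Comment list") -- web-page navigation residue, not code.
# 文章目录 ("Table of contents") -- web-page navigation residue, not code.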