def dtype(self):
    """Return the dtype held on this object.

    Returns:
      The value of the private `_dtype` attribute, unchanged.
    """
    stored_dtype = self._dtype
    return stored_dtype
# Serialize the tf.dtype as a string so that it can be unpickled on DataFlow.
Example source code for the Python DType() class
def __init__(self, dtype):
    """Initialize a domain that only accepts floating point dtypes.

    Args:
      dtype: the dtype for this domain; forwarded to the base-class
        initializer before validation.

    Raises:
      ValueError: if `self.dtype` (as set by the base initializer) is not a
        floating point dtype.
    """
    super(FloatDomain, self).__init__(dtype)
    # Validate after the base class has normalized/stored the dtype so that
    # `self.dtype.is_floating` is queried on the stored value.
    if not self.dtype.is_floating:
        raise ValueError(
            'FloatDomain must be initialized with a floating point dtype.')
def vocabulary_file(self):
    """Return the vocabulary file associated with this object.

    Returns:
      The value of the private `_vocabulary_file` attribute, unchanged.
    """
    path = self._vocabulary_file
    return path
# Serialize the tf.dtype as a string so that it can be unpickled on DataFlow.
def __setstate__(self, state):
    """Restore instance attributes from a pickled `state` mapping.

    `state['dtype']` is converted with `tf.as_dtype` (presumably it was
    stored in a serializable form by a matching `__getstate__` that is not
    shown here -- confirm). The remaining entries are copied verbatim onto
    the underscore-prefixed attributes of the same name.

    Args:
      state: mapping with keys 'dtype', 'is_categorical', 'min_value',
        'max_value' and 'vocabulary_file'.

    Raises:
      KeyError: if any of those keys is missing.
    """
    self._dtype = tf.as_dtype(state['dtype'])
    # Copied in the same order as the original explicit assignments.
    for key in ('is_categorical', 'min_value', 'max_value', 'vocabulary_file'):
        setattr(self, '_' + key, state[key])
def __init__(self, dtype):
    """Initialize a domain restricted to `tf.string`.

    Args:
      dtype: the dtype for this domain; forwarded to the base initializer.

    Raises:
      ValueError: if the stored `self.dtype` is not `tf.string`.
    """
    super(StringDomain, self).__init__(dtype)
    is_string = not (self.dtype != tf.string)
    if not is_string:
        raise ValueError('StringDomain must be initialized with a string dtype.')
def __init__(self, dtype):
    """Initialize a domain restricted to `tf.bool`.

    Args:
      dtype: the dtype for this domain; forwarded to the base initializer.

    Raises:
      ValueError: if the stored `self.dtype` is not `tf.bool`.
    """
    super(BoolDomain, self).__init__(dtype)
    is_bool = not (self.dtype != tf.bool)
    if not is_bool:
        raise ValueError('BoolDomain must be initialized with a boolean dtype.')
def _dtype_to_domain(dtype):
    """Create an appropriate Domain for the given dtype.

    Checks are applied in order: integer, floating, tf.string, tf.bool.

    Args:
      dtype: a tf.DType.

    Returns:
      An IntDomain, FloatDomain, StringDomain or BoolDomain built from `dtype`.

    Raises:
      ValueError: if `dtype` matches none of the supported categories.
    """
    if dtype.is_integer:
        domain_cls = IntDomain
    elif dtype.is_floating:
        domain_cls = FloatDomain
    elif dtype == tf.string:
        domain_cls = StringDomain
    elif dtype == tf.bool:
        domain_cls = BoolDomain
    else:
        raise ValueError('Schema cannot accommodate dtype: {}'.format(dtype))
    return domain_cls(dtype)
def as_feature_spec(self, column):
    """Build a `tf.FixedLenFeature` parsing spec for a fixed-size column.

    Args:
      column: a column schema exposing `is_fixed_size()`, `tf_shape()` and
        `domain.dtype`.

    Returns:
      A `tf.FixedLenFeature` with the column's shape and dtype, using
      `self.default_value` as its default.

    Raises:
      ValueError: if the column is not fixed-size, or if its dtype is not in
        `_TF_EXAMPLE_ALLOWED_TYPES`.
    """
    if not column.is_fixed_size():
        raise ValueError('A column of unknown size cannot be represented as '
                         'fixed-size.')
    feature_dtype = column.domain.dtype
    if feature_dtype not in _TF_EXAMPLE_ALLOWED_TYPES:
        message = ('tf.Example parser supports only types {}, so it is '
                   'invalid to generate a feature_spec with type '
                   '{}.').format(_TF_EXAMPLE_ALLOWED_TYPES, repr(feature_dtype))
        raise ValueError(message)
    shape = column.tf_shape().as_list()
    return tf.FixedLenFeature(shape, feature_dtype, self.default_value)
def as_feature_spec(self, column):
    """Build a `tf.VarLenFeature` parsing spec for a variable-length column.

    Args:
      column: a column schema exposing `domain.dtype`.

    Returns:
      A `tf.VarLenFeature` of the column's dtype.

    Raises:
      ValueError: if the column's dtype is not in `_TF_EXAMPLE_ALLOWED_TYPES`.
    """
    feature_dtype = column.domain.dtype
    if feature_dtype in _TF_EXAMPLE_ALLOWED_TYPES:
        return tf.VarLenFeature(feature_dtype)
    message = ('tf.Example parser supports only types {}, so it is '
               'invalid to generate a feature_spec with type '
               '{}.').format(_TF_EXAMPLE_ALLOWED_TYPES, repr(feature_dtype))
    raise ValueError(message)
def as_batched_placeholder(self, column):
    """Create a sparse placeholder for a batch of values of `column`.

    The placeholder shape is the column's shape with an extra leading
    dimension of unknown size (`None`) for the batch.

    Args:
      column: a column schema exposing `domain.dtype` and `tf_shape()`.

    Returns:
      The result of `tf.sparse_placeholder` for the column dtype and
      batched shape.
    """
    value_dtype = column.domain.dtype
    batched_shape = [None]
    batched_shape.extend(column.tf_shape().as_list())
    return tf.sparse_placeholder(value_dtype, batched_shape)
def as_batched_placeholder(self, column):
    """Return a sparse placeholder whose shape is `[None] + column shape`.

    Args:
      column: a column schema exposing `domain.dtype` and `tf_shape()`.

    Returns:
      The placeholder created by `tf.sparse_placeholder`.
    """
    value_dtype = column.domain.dtype
    per_example_shape = column.tf_shape().as_list()
    # Leading None: the batch dimension is left unspecified.
    return tf.sparse_placeholder(value_dtype, [None] + per_example_shape)
def _from_parse_feature(parse_feature):
    """Convert a single feature spec to a ColumnSchema.

    Supported specs and resulting shapes:
      - tf.FixedLenFeature: the spec's own shape, FixedColumnRepresentation.
      - tf.VarLenFeature: shape [None], ListColumnRepresentation.
      - tf.SparseFeature: shape [size], SparseColumnRepresentation with one
        index field.

    Args:
      parse_feature: a tf parsing spec object.

    Returns:
      A ColumnSchema describing the feature.

    Raises:
      ValueError: for tf.FixedLenSequenceFeature (not supported yet) or any
        unrecognized spec type.
    """
    if isinstance(parse_feature, tf.FixedLenFeature):
        fixed_repr = FixedColumnRepresentation(parse_feature.default_value)
        return ColumnSchema(parse_feature.dtype, parse_feature.shape,
                            fixed_repr)

    if isinstance(parse_feature, tf.FixedLenSequenceFeature):
        raise ValueError('DatasetSchema does not support '
                         'FixedLenSequenceFeature yet.')

    if isinstance(parse_feature, tf.VarLenFeature):
        list_repr = ListColumnRepresentation()
        return ColumnSchema(parse_feature.dtype, [None], list_repr)

    if isinstance(parse_feature, tf.SparseFeature):
        index = SparseIndexField(name=parse_feature.index_key,
                                 is_sorted=parse_feature.already_sorted)
        sparse_repr = SparseColumnRepresentation(
            value_field_name=parse_feature.value_key,
            index_fields=[index])
        return ColumnSchema(parse_feature.dtype, [parse_feature.size],
                            sparse_repr)

    raise ValueError('Cannot interpret feature spec {} with type {}'.format(
        parse_feature, type(parse_feature)))
def assert_valid_dtypes(tensors):
    """Asserts tensors are all valid types (see `_valid_dtypes`).

    Each tensor's `dtype.base_dtype` must be a member of `valid_dtypes()`.

    Args:
      tensors: Tensors to check.

    Raises:
      ValueError: If any tensor is not a valid type.
    """
    allowed = valid_dtypes()
    for tensor in tensors:
        base = tensor.dtype.base_dtype
        if base in allowed:
            continue
        raise ValueError("Invalid type %r for %s, expected: %s." %
                         (base, tensor.name, list(allowed)))
def constant_value(value_or_tensor_or_var, dtype=None):
    """Returns value if value_or_tensor_or_var has a constant value.

    Non-Tensor, non-Variable inputs are returned unchanged. Variables always
    yield None. Tensors are resolved with `tensor_util.constant_value`.

    Args:
      value_or_tensor_or_var: A value, a `Tensor` or a `Variable`.
      dtype: Optional `tf.dtype`, if set it would check it has the right
        dtype.

    Returns:
      The constant value or None if it not constant.

    Raises:
      ValueError: if value_or_tensor_or_var is None or the tensor_variable has the
        wrong dtype.
    """
    if value_or_tensor_or_var is None:
        raise ValueError('value_or_tensor_or_var cannot be None')
    if not isinstance(value_or_tensor_or_var, (ops.Tensor, variables.Variable)):
        return value_or_tensor_or_var
    if dtype and value_or_tensor_or_var.dtype != dtype:
        raise ValueError('It has the wrong type %s instead of %s' %
                         (value_or_tensor_or_var.dtype, dtype))
    if isinstance(value_or_tensor_or_var, variables.Variable):
        return None
    return tensor_util.constant_value(value_or_tensor_or_var)
def _bbox_to_mask_fixed_size(yy, region_size, output_size, dtype):
    """Build a mask with `_bbox_to_mask` and resize it to `output_size`.

    If the mask from `_bbox_to_mask` has zero elements, an all-zero tensor of
    shape `output_size` and `dtype` is used in its place before resizing.
    A trailing channel axis is added for `tf.image.resize_images` and removed
    again afterwards.

    Args:
      yy: bounding-box input forwarded to `_bbox_to_mask` (format defined
        there -- not visible here).
      region_size: forwarded to `_bbox_to_mask`.
      output_size: target spatial size; presumably a 2-element (height, width)
        -- confirm against callers.
      dtype: dtype forwarded to `_bbox_to_mask` and used for the zero fallback.

    Returns:
      The resized mask tensor with the temporary channel axis dropped.
    """
    raw_mask = _bbox_to_mask(yy, region_size, dtype)
    element_count = tf.reduce_prod(tf.shape(raw_mask))
    has_area = tf.greater(element_count, 0)
    safe_mask = tf.cond(has_area,
                        lambda: raw_mask,
                        lambda: tf.zeros(output_size, dtype))
    resized = tf.image.resize_images(safe_mask[..., tf.newaxis], output_size)
    return resized[..., 0]
def assert_valid_dtypes(tensors):
    """Asserts tensors are all valid types (see `_valid_dtypes`).

    Compares each tensor's `dtype.base_dtype` against `valid_dtypes()` and
    raises on the first mismatch.

    Args:
      tensors: Tensors to check.

    Raises:
      ValueError: If any tensor is not a valid type.
    """
    accepted = valid_dtypes()
    for candidate in tensors:
        base_type = candidate.dtype.base_dtype
        if base_type not in accepted:
            expected = list(accepted)
            raise ValueError("Invalid type %r for %s, expected: %s." %
                             (base_type, candidate.name, expected))
def constant_value(value_or_tensor_or_var, dtype=None):
    """Returns value if value_or_tensor_or_var has a constant value.

    Args:
      value_or_tensor_or_var: A value, a `Tensor` or a `Variable`.
      dtype: Optional `tf.dtype`, if set it would check it has the right
        dtype.

    Returns:
      The input itself when it is neither a Tensor nor a Variable; None for a
      Variable; otherwise `tensor_util.constant_value` of the Tensor (None if
      it not constant).

    Raises:
      ValueError: if value_or_tensor_or_var is None or the tensor_variable has the
        wrong dtype.
    """
    if value_or_tensor_or_var is None:
        raise ValueError('value_or_tensor_or_var cannot be None')
    is_graph_object = isinstance(value_or_tensor_or_var,
                                 (ops.Tensor, variables.Variable))
    if not is_graph_object:
        return value_or_tensor_or_var
    if dtype and value_or_tensor_or_var.dtype != dtype:
        raise ValueError('It has the wrong type %s instead of %s' %
                         (value_or_tensor_or_var.dtype, dtype))
    if isinstance(value_or_tensor_or_var, variables.Variable):
        result = None
    else:
        result = tensor_util.constant_value(value_or_tensor_or_var)
    return result
def get_dtype(dtype):
    """
    Look up the tf.dtype that corresponds to a type-name string.
    :param dtype: a str, e.g. "int32"
    :return: the entry for `dtype` in the module-level `__str2dtype` table,
        or tf.int32 when the name is not present in that table
    """
    assert isinstance(dtype, str)
    return __str2dtype[dtype] if dtype in __str2dtype else tf.int32
def pad_sequences(sequences, maxlen=None, dtype='int32', padding='post',
                  truncating='post', value=0.):
    """ pad_sequences.

    Pad each sequence to the same length: the length of the longest sequence.
    If maxlen is provided, any sequence longer than maxlen is truncated to
    maxlen. Truncation happens off either the beginning or the end (default)
    of the sequence. Supports pre-padding and post-padding (default).

    Args:
        sequences: list of lists where each element is a sequence.
        maxlen: a `int`, maximum length. Defaults to the length of the
            longest sequence (0 when `sequences` is empty).
        dtype: type to cast the resulting sequence.
        padding: 'pre' or 'post', pad either before or after each sequence.
        truncating: 'pre' or 'post', remove values from sequences larger than
            maxlen either in the beginning or in the end of the sequence
        value: `float`, value to pad the sequences to the desired value.

    Returns:
        x: `numpy array` with dimensions (number_of_sequences, maxlen)

    Raises:
        ValueError: if `truncating` or `padding` is not 'pre' or 'post'.
    """
    # Validate options up front so they are reported even when every
    # sequence is empty (the per-sequence loop would never reach the check).
    if truncating not in ('pre', 'post'):
        raise ValueError("Truncating type '%s' not understood" % truncating)
    if padding not in ('pre', 'post'):
        raise ValueError("Padding type '%s' not understood" % padding)

    lengths = [len(s) for s in sequences]
    nb_samples = len(sequences)
    if maxlen is None:
        # np.max raises on an empty list; an empty batch pads to width 0.
        maxlen = np.max(lengths) if lengths else 0

    x = (np.ones((nb_samples, maxlen)) * value).astype(dtype)
    for idx, s in enumerate(sequences):
        if len(s) == 0:
            continue  # empty list was found; the row stays all padding
        if truncating == 'pre':
            # Keep the last `maxlen` items. An explicit start index is used
            # because s[-0:] would return the whole sequence when maxlen == 0.
            trunc = s[max(len(s) - maxlen, 0):]
        else:
            trunc = s[:maxlen]
        if len(trunc) == 0:
            continue  # only reachable when maxlen == 0
        if padding == 'post':
            x[idx, :len(trunc)] = trunc
        else:
            x[idx, -len(trunc):] = trunc
    return x