def convert_to_type(type_like):
    """Converts `type_like` to a `Type`.

    If `type_like` is already a `ResultType`, it is returned unchanged.
    Otherwise the following conversions are performed:

    * Python tuples become `TupleType`s; items are recursively converted.
    * A `tf.TensorShape` becomes a corresponding `TensorType` with
      `dtype=float32`. Must be fully defined.
    * Lists of `shape + [dtype]` (e.g. `[3, 4, 'int32']` or
      `[3, 4, tf.int32]`) become `TensorType`s, with the default
      `dtype=float32` if the trailing dtype is omitted.
    * A `tf.DType` or stringified version thereof (e.g. `'int64'`)
      becomes a corresponding scalar `TensorType((), dtype)`.
    * An integer `vector_len` becomes a corresponding vector
      `TensorType((vector_len,), dtype=float32)`.

    Args:
      type_like: Described above.

    Returns:
      A `Type`.

    Raises:
      TypeError: If `type_like` cannot be converted to a `Type`, or if it is
        a `tf.TensorShape` that is not fully defined.
    """
    if isinstance(type_like, ResultType):
        return type_like
    if isinstance(type_like, tf.TensorShape):
        # Check this *before* calling as_list(), otherwise it throws.
        if not type_like.is_fully_defined():
            raise TypeError('shape %s is not fully defined' % type_like)
        return TensorType(type_like.as_list())
    if isinstance(type_like, tuple):
        return TupleType(convert_to_type(item) for item in type_like)
    # A dtype may be spelled either as a tf.DType or as its string name; the
    # list branch and the scalar branch must accept the same spellings.
    dtype_types = (tf.DType,) + tuple(six.string_types)
    if isinstance(type_like, list):
        # A trailing dtype entry splits the list into shape + [dtype].
        if type_like and isinstance(type_like[-1], dtype_types):
            return TensorType(type_like[:-1], dtype=type_like[-1])
        return TensorType(type_like)
    if isinstance(type_like, dtype_types):
        return TensorType((), dtype=type_like)
    if isinstance(type_like, numbers.Integral):
        # A bare integer is shorthand for a vector of that length.
        return TensorType((type_like,))
    raise TypeError('Cannot convert %s to a type.' % (type_like,))
评论列表
文章目录