result_types.py source code

python

Project: fold  Author: tensorflow
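import numbers

import six
import tensorflow as tf


# ResultType, TensorType, and TupleType are defined elsewhere in result_types.py.
def convert_to_type(type_like):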
  """Converts `type_like` to a `Type`.

  If `type_like` is already a `Type`, it is returned. The following
  conversions are performed:

  * Python tuples become `TupleType`s; items are recursively converted.

  * A `tf.TensorShape` becomes a corresponding `TensorType` with
  `dtype=float32`. The shape must be fully defined; otherwise a
  `TypeError` is raised.

  * Lists of `shape + [dtype]` (e.g. `[3, 4, 'int32']`) become
  `TensorType`s, with the default `dtype=float32` if omitted.

  * A `tf.DType` or stringified version thereof (e.g. `'int64'`)
  becomes a corresponding scalar `TensorType((), dtype)`.

  * An integer `vector_len` becomes a corresponding vector
  `TensorType((vector_len,), dtype=float32)`.

  Args:
    type_like: Described above.

  Returns:
    A `Type`.

  Raises:
    TypeError: If `type_like` cannot be converted to a `Type`.

  """
  if isinstance(type_like, ResultType):
    return type_like
  if isinstance(type_like, tf.TensorShape):
    # Check this *before* calling as_list(); otherwise it throws.
    if not type_like.is_fully_defined():
      raise TypeError('shape %s is not fully defined' % type_like)
    return TensorType(type_like.as_list())
  if isinstance(type_like, tuple):
    return TupleType(convert_to_type(item) for item in type_like)
  if isinstance(type_like, list):
    if type_like and isinstance(type_like[-1], six.string_types):
      return TensorType(type_like[:-1], dtype=type_like[-1])
    else:
      return TensorType(type_like)
  if isinstance(type_like, tf.DType) or isinstance(type_like, six.string_types):
    return TensorType((), dtype=type_like)
  if isinstance(type_like, numbers.Integral):
    return TensorType((type_like,))
  raise TypeError('Cannot convert %s to a type.' % (type_like,))
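
The dispatch order in the function can be tried out without installing TensorFlow Fold. The following is a minimal sketch under stated assumptions, not the library's implementation: `TensorType` and `TupleType` here are hypothetical namedtuple stand-ins, `convert_sketch` is an illustrative name, `str` replaces `six.string_types`, and the `tf.TensorShape` and `tf.DType` branches are left out so that the snippet has no TensorFlow dependency.

```python
import collections
import numbers

# Hypothetical stand-ins for illustration only; NOT the TF Fold classes.
TensorType = collections.namedtuple('TensorType', ['shape', 'dtype'])
TupleType = collections.namedtuple('TupleType', ['items'])


def convert_sketch(type_like):
  """Mirrors the branch order of convert_to_type, minus the tf.* branches."""
  # Already-converted values pass through unchanged. This check must come
  # first because the namedtuple stand-ins are themselves tuples.
  if isinstance(type_like, (TensorType, TupleType)):
    return type_like
  # Tuples are converted item by item, recursively.
  if isinstance(type_like, tuple):
    return TupleType(tuple(convert_sketch(item) for item in type_like))
  # Lists are shapes, optionally followed by a dtype string.
  if isinstance(type_like, list):
    if type_like and isinstance(type_like[-1], str):
      return TensorType(tuple(type_like[:-1]), type_like[-1])
    return TensorType(tuple(type_like), 'float32')
  # A bare dtype string becomes a scalar type.
  if isinstance(type_like, str):
    return TensorType((), type_like)
  # An integer becomes a float32 vector of that length.
  if isinstance(type_like, numbers.Integral):
    return TensorType((type_like,), 'float32')
  raise TypeError('Cannot convert %s to a type.' % (type_like,))


print(convert_sketch([3, 4, 'int32']))
print(convert_sketch((5, 'int64')))
```

The order of the checks matters: tuples and lists are handled before strings and integers, so `(5, 'int64')` becomes a two-item tuple type rather than a single tensor type, while `[5, 'int64']` would be read as shape `(5,)` with dtype `int64`.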