utils.py source code

Project: zhusuan    Author: thu-ml
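```python
import tensorflow as tf

# assert_same_dtype and assert_same_specific_dtype are helpers defined
# elsewhere in the same utils.py (not shown here).


def assert_same_float_dtype(tensors_with_name, dtype=None):
    """
    Check that all tensors in `tensors_with_name` share the same
    floating-point dtype.

    :param tensors_with_name: A list of (tensor, tensor_name) pairs.
    :param dtype: Expected dtype. If `None`, it is inferred from the tensors.
    :return: The common dtype of the tensors.
    """
    floating_types = [tf.float16, tf.float32, tf.float64]
    if dtype is None:
        # No expected dtype: the tensors must agree with each other, and the
        # shared dtype must be one of the floating types.
        return assert_same_specific_dtype(tensors_with_name, floating_types)
    elif dtype in floating_types:
        # Expected floating dtype given: every tensor is checked against it.
        return assert_same_dtype(tensors_with_name, dtype)
    else:
        raise TypeError("The argument 'dtype' must be in %s" % floating_types)
```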
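The function above only dispatches to two helpers that are not shown on this page. Below is a minimal, TensorFlow-free sketch that makes the control flow runnable. The `FakeTensor` stand-in, the string dtypes in `FLOATING_TYPES`, and the bodies of `assert_same_dtype` and `assert_same_specific_dtype` are all hypothetical illustrations written for this example. They are not zhusuan's actual implementations; they only assume that a tensor exposes a `.dtype` attribute.

```python
from collections import namedtuple

# Hypothetical stand-in for a tensor: anything with a `.dtype` attribute works.
FakeTensor = namedtuple("FakeTensor", ["dtype"])

# Plain strings stand in for tf.float16 / tf.float32 / tf.float64.
FLOATING_TYPES = ["float16", "float32", "float64"]


def assert_same_dtype(tensors_with_name, dtype=None):
    """Hypothetical helper: all tensors share one dtype (equal to `dtype` if given)."""
    expected = dtype
    for tensor, name in tensors_with_name:
        if expected is None:
            expected = tensor.dtype  # the first tensor fixes the expected dtype
        elif tensor.dtype != expected:
            raise TypeError("%s has dtype %s, expected %s"
                            % (name, tensor.dtype, expected))
    return expected


def assert_same_specific_dtype(tensors_with_name, dtypes):
    """Hypothetical helper: tensors share one dtype, and that dtype is in `dtypes`."""
    common = assert_same_dtype(tensors_with_name)
    if common not in dtypes:
        raise TypeError("dtype %s is not one of %s" % (common, dtypes))
    return common


def assert_same_float_dtype(tensors_with_name, dtype=None):
    """Same dispatch logic as the zhusuan function, using string dtypes."""
    if dtype is None:
        return assert_same_specific_dtype(tensors_with_name, FLOATING_TYPES)
    elif dtype in FLOATING_TYPES:
        return assert_same_dtype(tensors_with_name, dtype)
    else:
        raise TypeError("The argument 'dtype' must be in %s" % FLOATING_TYPES)


if __name__ == "__main__":
    x = FakeTensor("float32")
    y = FakeTensor("float32")
    print(assert_same_float_dtype([(x, "x"), (y, "y")]))  # prints float32
    try:
        assert_same_float_dtype([(x, "x"), (FakeTensor("float64"), "z")])
    except TypeError as e:
        print("TypeError:", e)  # mismatched float dtypes are rejected
```

This design lets a single "all same" check serve both cases. When `dtype` is given, it seeds the comparison. When `dtype` is omitted, the first tensor seeds it, and the result is then checked against the allowed floating types.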