util.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:tefla 作者: openAGI 项目源码 文件源码
def assert_valid_dtypes(tensors):
    """Asserts tensors are all valid types (see `_valid_dtypes`).
    Args:
        tensors: Tensors to check.
    Raises:
        ValueError: If any tensor is not a valid type.
    """
    valid_dtype = valid_dtypes()
    for t in tensors:
        dtype = t.dtype.base_dtype
        if dtype not in valid_dtype:
            raise ValueError("Invalid type %r for %s, expected: %s." %
                             (dtype, t.name, [v for v in valid_dtype]))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号