def assert_valid_dtypes(tensors):
    """Assert that every tensor's base dtype is in ``valid_dtypes()``.

    The allowed set is fetched once from ``valid_dtypes()`` and each tensor's
    ``dtype.base_dtype`` is checked for membership in it.

    Args:
      tensors: Iterable of tensor-like objects exposing ``.dtype.base_dtype``
        and ``.name``. An empty iterable is accepted and checks nothing.

    Returns:
      None.

    Raises:
      ValueError: If any tensor's base dtype is not in ``valid_dtypes()``. The
        message names the offending dtype, the tensor's name and the list of
        allowed dtypes.
    """
    # Fetch the allowed dtypes once rather than once per tensor.
    allowed = valid_dtypes()
    for t in tensors:
        # NOTE(review): base_dtype presumably strips a reference variant so
        # ref and non-ref tensors of the same type compare equal — confirm.
        dtype = t.dtype.base_dtype
        if dtype not in allowed:
            raise ValueError("Invalid type %r for %s, expected: %s." %
                             (dtype, t.name, list(allowed)))
评论列表
文章目录