def get_dtype(dtype):
    """
    A helper function to get a tf.dtype from its string name.

    :param dtype: a str naming the dtype, e.g. "int32". It is looked up in the
        module-level ``__str2dtype`` mapping (defined elsewhere in this module;
        presumably a dict of name -> tf.dtype — confirm against its definition).
    :return: the value stored in ``__str2dtype`` for ``dtype``. If ``dtype``
        is not a key of that mapping, ``tf.int32`` is returned as a fallback;
        no error is raised for unknown names.
    :raises TypeError: if ``dtype`` is not a str.
    """
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, which would let non-str values (e.g. an existing tf.DType)
    # slip through and silently fall back to tf.int32 below.
    if not isinstance(dtype, str):
        raise TypeError(
            "dtype must be a str, got {}".format(type(dtype).__name__))
    if dtype in __str2dtype:
        return __str2dtype[dtype]
    # Unknown dtype names deliberately fall back to tf.int32.
    return tf.int32
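# NOTE: web-page residue from the page this code was copied from, not code:
# "评论列表" ("comment list"), "文章目录" ("table of contents").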