def tf_obj_shape(input):
    """Convert a TensorFlow object to a shape tuple.

    Args:
        input: A ``tf.TensorShape``, ``tf.Tensor``, ``tf.AttrValue`` or
            ``tf.NodeDef`` whose shape should be extracted.

    Returns:
        tuple: The shape of ``input`` as a tuple of Python ints.

    Raises:
        TypeError: If ``input`` is not one of the supported types, or if a
            ``tf.TensorShape``/``tf.Tensor`` has a dimension of unknown size
            (``int(None)`` fails).
        ValueError: If a ``tf.TensorShape``/``tf.Tensor`` has unknown rank
            (raised by ``TensorShape.as_list()``).
    """
    if isinstance(input, tf.TensorShape):
        # ``as_list()`` returns plain ints (None for unknown dims) under both
        # TF1, where iterating a TensorShape yields ``Dimension`` objects, and
        # TF2, where iteration yields ints that have no ``.value`` attribute.
        return tuple(int(dim) for dim in input.as_list())
    elif isinstance(input, tf.Tensor):
        return tf_obj_shape(input.get_shape())
    elif isinstance(input, tf.AttrValue):
        # Sizes are read straight from the TensorShapeProto, where an unknown
        # dimension is encoded as -1; such values are passed through as-is.
        # NOTE(review): an AttrValue whose shape has ``unknown_rank`` set has
        # no dims and therefore yields () here — confirm callers expect that.
        return tuple(int(dim.size) for dim in input.shape.dim)
    elif isinstance(input, tf.NodeDef):
        # Test membership before indexing: subscripting a missing key of a
        # protobuf message-valued map inserts an empty entry, which would
        # silently mutate the caller's NodeDef. A node with no 'shape' attr
        # still yields (), matching the previous return value.
        if 'shape' not in input.attr:
            return ()
        return tf_obj_shape(input.attr['shape'])
    else:
        raise TypeError("Input to `tf_obj_shape` has the wrong type.")
评论列表
文章目录