def get_shape(tfobj_or_name, graph):
    """
    Return the shape of a tensor as a plain Python list.

    :param tfobj_or_name: either a tf.Tensor, tf.Operation or a name to either
    :param graph: tf.Graph, a TensorFlow Graph object
    :return: list with one entry per dimension, taken from the tensor's
             ``get_shape().as_list()``; any dimension reported as ``None``
             (size not known) is returned as ``-1`` instead
    """
    # NOTE(review): validated_graph / get_tensor are defined elsewhere;
    # presumably they check the graph and resolve the name/op to a tensor.
    checked_graph = validated_graph(graph)
    tensor = get_tensor(tfobj_or_name, checked_graph)
    dims = tensor.get_shape().as_list()
    # Normalise unknown (None) dimensions to -1 so every entry is an int.
    return [(-1 if dim is None else dim) for dim in dims]
评论列表
文章目录