def get_tensor(tfobj_or_name, graph):
    """
    Get a :py:class:`tf.Tensor` object.

    :param tfobj_or_name: either a :py:class:`tf.Tensor`, a :py:class:`tf.Operation`,
                          or the name (a string) of either.
    :param graph: the :py:class:`tf.Graph` expected to contain the tensor. It is
                  passed through ``validated_graph`` before being used.
    :return: the matching :py:class:`tf.Tensor`. A :py:class:`tf.Tensor` argument is
             returned as-is once it has passed the same-graph check.
    :raises TypeError: if ``tfobj_or_name`` is neither a tensor, an operation nor a
                       string, or if the graph lookup yields something that is not
                       a :py:class:`tf.Tensor`.

    Errors raised by ``validated_graph``, ``_assert_same_graph`` or
    ``graph.get_tensor_by_name`` (e.g. a ``KeyError`` for a name that is not in
    the graph, per the TensorFlow docs) propagate unchanged.
    """
    graph = validated_graph(graph)
    # NOTE(review): presumably verifies that a Tensor/Operation argument belongs to
    # ``graph``; confirm against the helper's definition.
    _assert_same_graph(tfobj_or_name, graph)
    if isinstance(tfobj_or_name, tf.Tensor):
        return tfobj_or_name

    # An Operation is resolved through its name; anything else must already be a name.
    if isinstance(tfobj_or_name, tf.Operation):
        name = tfobj_or_name.name
    else:
        name = tfobj_or_name
    if not isinstance(name, six.string_types):
        raise TypeError('invalid tensor request for {} of {}'.format(name, type(name)))

    # NOTE(review): ``tensor_name`` is assumed to turn an op name into a full tensor
    # name (e.g. adding an output index) -- confirm against its definition.
    full_name = tensor_name(name, graph=None)
    tnsr = graph.get_tensor_by_name(full_name)
    # Explicit check rather than ``assert``: assert statements are stripped when
    # Python runs with -O, which would silently disable this validation.
    if not isinstance(tnsr, tf.Tensor):
        err_msg = 'cannot locate tensor {} in the current graph, got [type {}] {}'
        raise TypeError(err_msg.format(full_name, type(tnsr), tnsr))
    return tnsr
评论列表
文章目录