def op_name(tfobj_or_name, graph=None):
    """
    Derive the :py:class:`tf.Operation` name from a :py:class:`tf.Operation` or
    :py:class:`tf.Tensor` object, or its name.

    If a name is provided and the graph is not, the operation name is derived purely
    from the string using TensorFlow's naming convention: everything before the first
    ``":"`` (e.g. ``"my_op:0"`` -> ``"my_op"``); no graph lookup is performed.
    If the input is a TensorFlow object, or the graph is given, the operation is looked
    up with ``get_op`` in the associated graph and the name of the result is returned.

    :param tfobj_or_name: either a :py:class:`tf.Tensor`, :py:class:`tf.Operation` or
                          a name to either.
    :param graph: a :py:class:`tf.Graph` object containing the operation. Optional;
                  when omitted, a string input is parsed without consulting any graph.
    :return: the operation name.
    :raises ValueError: if ``graph`` is None and the string name contains more than
                        one ``":"``.
    :raises TypeError: if ``graph`` is None and the input is neither a string nor an
                       object with a ``graph`` attribute.
    """
    if graph is not None:
        return get_op(tfobj_or_name, graph).name
    if isinstance(tfobj_or_name, six.string_types):
        # If input is a string, assume it is a name and infer the corresponding operation name.
        # WARNING: this depends on TensorFlow's operation naming convention.
        name_parts = tfobj_or_name.split(":")
        # Explicit check rather than ``assert``: assertions are stripped under ``python -O``,
        # which would let a malformed name such as "a:b:c" silently map to "a".
        if len(name_parts) > 2:
            raise ValueError(
                "invalid tf.Operation name {!r}: expected at most one ':'".format(tfobj_or_name))
        return name_parts[0]
    if hasattr(tfobj_or_name, 'graph'):
        # TensorFlow objects carry their own graph; use it to verify the operation exists.
        return get_op(tfobj_or_name, tfobj_or_name.graph).name
    raise TypeError('invalid tf.Operation name query type {}'.format(type(tfobj_or_name)))
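# NOTE: stray web-page navigation text ("评论列表" = "comment list",
# "文章目录" = "table of contents") was pasted here; it is not part of the module.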