utils.py source code

Project: spark-deep-learning    Author: databricks
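import six  # Python 2/3 compatibility; provides six.string_types

# get_op(tfobj_or_name, graph) is assumed to be defined elsewhere in this utils.py module.


def op_name(tfobj_or_name, graph=None):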
    """
    Derive the :py:class:`tf.Operation` name from a :py:class:`tf.Operation` or
    :py:class:`tf.Tensor` object, or its name.
    If a name is provided and the graph is not, we will derive the operation name based on
    TensorFlow's naming convention.
    If the input is a TensorFlow object, or the graph is given, we also check that
    the operation exists in the associated graph.

    :param tfobj_or_name: a :py:class:`tf.Tensor`, a :py:class:`tf.Operation`, or
                          the name of either.
    :param graph: an optional :py:class:`tf.Graph` object containing the operation.
                  If omitted, a string name is parsed without looking it up in a graph.
    """
    if graph is not None:
        return get_op(tfobj_or_name, graph).name
    if isinstance(tfobj_or_name, six.string_types):
        # If input is a string, assume it is a name and infer the corresponding operation name.
        # WARNING: this depends on TensorFlow's operation naming convention
        name = tfobj_or_name
        name_parts = name.split(":")
        assert len(name_parts) <= 2, name_parts
        return name_parts[0]
    elif hasattr(tfobj_or_name, 'graph'):
        return get_op(tfobj_or_name, tfobj_or_name.graph).name
    else:
        raise TypeError('invalid tf.Operation name query type {}'.format(type(tfobj_or_name)))
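The string branch relies on TensorFlow's naming convention: a tensor is named `<op_name>:<output_index>` (for example, `dense/MatMul:0` is output 0 of the operation `dense/MatMul`), while an operation name has no `:` suffix. The minimal sketch below isolates that branch so it runs without TensorFlow installed. `op_name_from_string` is a hypothetical helper for illustration, not part of spark-deep-learning, and it assumes Python 3 (`str` in place of `six.string_types`).

```python
def op_name_from_string(name):
    """Return the operation name for a tensor or operation name string.

    Hypothetical helper mirroring the string branch of op_name():
    a tensor name "<op_name>:<output_index>" maps to "<op_name>",
    and an operation name (no ":" suffix) is returned unchanged.
    """
    if not isinstance(name, str):
        raise TypeError("expected a str, got {}".format(type(name)))
    parts = name.split(":")
    # At most one ":" is allowed, separating the op name from the output index.
    if len(parts) > 2:
        raise ValueError("malformed TensorFlow name: {!r}".format(name))
    return parts[0]


# A tensor name maps to the operation that produces it.
print(op_name_from_string("dense/MatMul:0"))  # dense/MatMul
# An operation name is returned as-is.
print(op_name_from_string("dense/MatMul"))    # dense/MatMul
```

Unlike the original, the sketch raises `ValueError` rather than using `assert`, because `assert` statements are skipped when Python runs with the `-O` flag.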