utils.py source code

python

Project: spark-deep-learning  Author: databricks
import six
import tensorflow as tf


def get_tensor(tfobj_or_name, graph):
    """
    Get a :py:class:`tf.Tensor` object

    :param tfobj_or_name: either a :py:class:`tf.Tensor`, :py:class:`tf.Operation` or
                          a name to either.
    :param graph: a :py:class:`tf.Graph` object containing the tensor.
                  This argument is required and is checked by ``validated_graph``.
    """
    graph = validated_graph(graph)
    _assert_same_graph(tfobj_or_name, graph)
    if isinstance(tfobj_or_name, tf.Tensor):
        return tfobj_or_name
    name = tfobj_or_name
    if isinstance(tfobj_or_name, tf.Operation):
        name = tfobj_or_name.name
    if not isinstance(name, six.string_types):
        raise TypeError('invalid tensor request for {} of {}'.format(name, type(name)))
    _tensor_name = tensor_name(name, graph=None)
    tnsr = graph.get_tensor_by_name(_tensor_name)
    err_msg = 'cannot locate tensor {} in the current graph, got [type {}] {}'
    assert isinstance(tnsr, tf.Tensor), err_msg.format(_tensor_name, type(tnsr), tnsr)
    return tnsr
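
The `tensor_name` helper that `get_tensor` relies on is not part of this excerpt. As a rough sketch of what such a name-normalization step might look like, the hypothetical function `as_tensor_name` below (an illustration only, not the library's implementation) assumes TensorFlow's `<op_name>:<output_index>` naming convention and treats a bare operation name as a reference to its first output.

```python
def as_tensor_name(name):
    """Normalize an operation or tensor name to '<op_name>:<output_index>'.

    Hypothetical helper for illustration; it is not taken from
    spark-deep-learning.
    """
    if not isinstance(name, str):
        raise TypeError('expected a name string, got {}'.format(type(name)))
    op_name, sep, index = name.partition(':')
    if not op_name:
        raise ValueError('missing operation name in {!r}'.format(name))
    if not sep:
        # Bare operation name: assume the caller wants the first output.
        return op_name + ':0'
    if not index.isdigit():
        raise ValueError('invalid output index in {!r}'.format(name))
    # Already a fully qualified tensor name.
    return name
```

With this sketch, `as_tensor_name('dense/BiasAdd')` gives `'dense/BiasAdd:0'`, while `as_tensor_name('x:1')` is returned unchanged. A string in that form is what `graph.get_tensor_by_name` expects.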