core.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:tflearn 作者: tflearn 项目源码 文件源码
def flatten(incoming, name="Flatten"):
    """ Flatten.

    Flatten the incoming Tensor.

    Input:
        (2+)-D `Tensor`.

    Output:
        2-D `Tensor` [batch, flatten_dims].

    Arguments:
        incoming: `Tensor`. The incoming tensor.

    """
    input_shape = utils.get_incoming_shape(incoming)
    assert len(input_shape) > 1, "Incoming Tensor shape must be at least 2-D"
    dims = int(np.prod(input_shape[1:]))
    x = reshape(incoming, [-1, dims], name)

    # Track output tensor.
    tf.add_to_collection(tf.GraphKeys.LAYER_TENSOR + '/' + name, x)

    return x
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号