tensor_node.py source code

python

Project: nengo_dl  Author: nengo
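import numpy as np
from nengo import Connection, Ensemble
from nengo.neurons import NeuronType

# NOTE: TensorNode and reshaped are assumed to be defined elsewhere in this
# module (tensor_node.py); they are not shown in this excerpt.


def tensor_layer(input, layer_func, shape_in=None, synapse=None,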
                 transform=1, return_conn=False, **layer_args):
    """A utility function to construct TensorNodes that apply some function
    to their input (analogous to the ``tf.layers`` syntax).

    Parameters
    ----------
    input : :class:`~nengo:nengo.base.NengoObject`
        Object providing input to the layer
    layer_func : callable or :class:`~nengo:nengo.neurons.NeuronType`
        A function that takes the value from ``input`` (represented as a
        ``tf.Tensor``) and maps it to some output value, or a Nengo neuron
        type, defining a nonlinearity that will be applied to ``input``.
    shape_in : tuple of int, optional
        If not None, reshape the input to the given shape
    synapse : float or :class:`~nengo:nengo.synapses.Synapse`, optional
        Synapse to apply on connection from ``input`` to this layer
    transform : :class:`~numpy:numpy.ndarray`, optional
        Transform matrix to apply on connection from ``input`` to this layer
    return_conn : bool, optional
        If True, also return the connection linking this layer to ``input``
    layer_args : dict, optional
        These arguments will be passed to ``layer_func`` if it is callable, or
        :class:`~nengo:nengo.Ensemble` if ``layer_func`` is a
        :class:`~nengo:nengo.neurons.NeuronType`

    Returns
    -------
    :class:`.TensorNode` or :class:`~nengo:nengo.ensemble.Neurons`
        A TensorNode that implements the given layer function (if
        ``layer_func`` was a callable), or a Neuron object with the given
        neuron type, connected to ``input``
    """

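    # determine the layer's input size, in order of precedence:
    # a 2D transform (its first axis is the post-connection size), then an
    # explicit ``shape_in``, then the output size of ``input``
    if isinstance(transform, np.ndarray) and transform.ndim == 2:
        size_in = transform.shape[0]
    elif shape_in is not None:
        size_in = np.prod(shape_in)
    else:
        size_in = input.size_out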

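    if isinstance(layer_func, NeuronType):
        # one neuron per input element; connect directly to the neurons so
        # that the neuron type acts as an elementwise nonlinearity
        node = Ensemble(size_in, 1, neuron_type=layer_func,
                        **layer_args).neurons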
    else:
        # add (ignored) time input and pass kwargs
        def node_func(_, x):
            return layer_func(x, **layer_args)

        # reshape input if necessary
        if shape_in is not None:
            node_func = reshaped(shape_in)(node_func)

        node = TensorNode(node_func, size_in=size_in)

    conn = Connection(input, node, synapse=synapse, transform=transform)

    return (node, conn) if return_conn else node
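
The two pieces of plain-Python logic in the function above, the precedence used to pick `size_in` and the closure that adds an ignored time argument, can be sketched without Nengo or TensorFlow. This is only an illustrative sketch: `resolve_size_in`, `wrap_layer`, and `scale` are hypothetical names that do not exist in nengo_dl, and the transform is represented by its shape tuple rather than by a NumPy array.

```python
from math import prod


def resolve_size_in(transform_shape=None, shape_in=None, input_size_out=None):
    """Hypothetical helper mirroring the size_in precedence in tensor_layer."""
    # a 2D transform takes priority: its first axis is the layer's input size
    if transform_shape is not None and len(transform_shape) == 2:
        return transform_shape[0]
    # otherwise an explicit shape_in is flattened to a single size
    if shape_in is not None:
        return prod(shape_in)
    # fall back to the output size of the input object
    return input_size_out


def wrap_layer(layer_func, **layer_args):
    """Hypothetical helper: add an (ignored) time argument and bind kwargs."""
    def node_func(_, x):
        return layer_func(x, **layer_args)
    return node_func


def scale(x, gain=1.0):
    """Hypothetical stand-in for a layer function (operates on a list here)."""
    return [gain * v for v in x]


node_func = wrap_layer(scale, gain=2.0)
doubled = node_func(0.001, [1.0, 2.0])  # the first argument (time) is ignored
```

In the real function the wrapped callable receives a `tf.Tensor` rather than a list, and the result is handed to `TensorNode`; the sketch only shows the calling convention.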