def tensor_layer(input, layer_func, shape_in=None, synapse=None,
                 transform=1, return_conn=False, **layer_args):
    """Create a layer that applies ``layer_func`` to the output of ``input``
    and connect ``input`` to it (analogous to the ``tf.layers`` syntax).

    The layer is either a TensorNode wrapping a callable, or the neurons of
    a new Ensemble when ``layer_func`` is a Nengo neuron type.

    Parameters
    ----------
    input : :class:`~nengo:nengo.base.NengoObject`
        Object providing input to the layer
    layer_func : callable or :class:`~nengo:nengo.neurons.NeuronType`
        Either a function that receives the value of ``input`` (as a
        ``tf.Tensor``) and returns the layer output, or a Nengo neuron type
        whose nonlinearity will be applied to ``input``
    shape_in : tuple of int, optional
        If not None, the input is reshaped to this shape before being passed
        to a callable ``layer_func``
    synapse : float or :class:`~nengo:nengo.synapses.Synapse`, optional
        Synapse used on the connection from ``input`` to the layer
    transform : :class:`~numpy:numpy.ndarray`, optional
        Transform used on the connection from ``input`` to the layer; if it
        is a 2D array, its number of rows determines the layer's input size
    return_conn : bool, optional
        If True, the connection from ``input`` to the layer is returned
        together with the layer
    layer_args : dict, optional
        Extra keyword arguments, forwarded to ``layer_func`` if it is
        callable, or to :class:`~nengo:nengo.Ensemble` if ``layer_func`` is
        a :class:`~nengo:nengo.neurons.NeuronType`

    Returns
    -------
    :class:`.TensorNode` or :class:`~nengo:nengo.ensemble.Neurons`
        A TensorNode implementing the layer function (callable
        ``layer_func``), or the Neurons object of an Ensemble using the
        given neuron type, connected to ``input``.  When ``return_conn`` is
        True, a ``(layer, connection)`` tuple is returned instead.
    """
    # Input dimensionality of the layer: a 2D transform matrix takes
    # precedence, then an explicit ``shape_in``, and only as a last resort
    # the output size of ``input``.
    has_matrix_transform = (
        isinstance(transform, np.ndarray) and transform.ndim == 2)
    if has_matrix_transform:
        size_in = transform.shape[0]
    elif shape_in is None:
        size_in = input.size_out
    else:
        size_in = np.prod(shape_in)

    if not isinstance(layer_func, NeuronType):
        # TensorNode functions are called with (time, input); the time value
        # is ignored here and the user's keyword arguments are bound in.
        def node_func(_, x):
            return layer_func(x, **layer_args)

        # Only wrap with the reshaping decorator when a shape was requested.
        if shape_in is None:
            wrapped_func = node_func
        else:
            wrapped_func = reshaped(shape_in)(node_func)

        layer = TensorNode(wrapped_func, size_in=size_in)
    else:
        # Neuron type given: build an ensemble with ``size_in`` neurons and
        # connect directly to its neurons.
        ensemble = Ensemble(size_in, 1, neuron_type=layer_func, **layer_args)
        layer = ensemble.neurons

    link = Connection(input, layer, synapse=synapse, transform=transform)

    if return_conn:
        return layer, link
    return layer
# (Web page residue, not part of the code: "Comment list" / "Table of contents")