def __init__(self, num_units, num_dims=1, input_dims=None, output_dims=None,
             priority_dims=None, non_recurrent_dims=None, tied=False,
             cell_fn=None, non_recurrent_fn=None):
    """Set up the configuration and the inner cell of a Grid RNN cell.

    Args:
      num_units: int. Number of units used in every dimension of this
        GridRNN cell.
      num_dims: int. Number of dimensions of the grid; must be >= 1.
      input_dims: int or list. Dimensions that receive input data.
      output_dims: int or list. Dimensions from which the output is
        recorded.
      priority_dims: int or list. Dimensions to treat as priority
        dimensions. If None, no dimension is prioritized.
      non_recurrent_dims: int or list. Dimensions that are not recurrent.
        Their transfer function is given by `non_recurrent_fn`.
      tied: bool. Whether weights are shared among the dimensions of this
        GridRNN cell. When the grid has non-recurrent dimensions, weights
        are shared within each group of recurrent and non-recurrent
        dimensions.
      cell_fn: function returning the recurrent cell object. It must have
        the signature
            def cell_func(num_units, input_size):
              # ...
        and return an object of type `RNNCell`. If None, an `LSTMCell`
        with default parameters is used.
      non_recurrent_fn: a tensorflow Op used as the transfer function of
        the non-recurrent dimensions. A falsy value (including the default
        None) selects `nn.relu`.

    Raises:
      ValueError: if `num_dims` is smaller than 1, or if `cell_fn` returns
        something that is not an `RNNCell`.
    """
    # Reject an empty grid before any configuration work is done.
    if num_dims < 1:
        raise ValueError('dims must be >= 1: {}'.format(num_dims))

    # Fall back to ReLU when the caller supplied no transfer function.
    transfer_fn = non_recurrent_fn if non_recurrent_fn else nn.relu
    self._config = _parse_rnn_config(
        num_dims, input_dims, output_dims, priority_dims,
        non_recurrent_dims, transfer_fn, tied, num_units)

    # The inner cell's input width is num_units for every dimension but one
    # (presumably the states of the other dimensions -- confirm against the
    # code that calls the cell).
    other_dims = self._config.num_dims - 1
    cell_input_size = other_dims * num_units

    if cell_fn is not None:
        # User-provided factory: store the result first, then verify its type.
        self._cell = cell_fn(num_units, cell_input_size)
        if not isinstance(self._cell, rnn_cell.RNNCell):
            raise ValueError('cell_fn must return an object of type RNNCell')
    else:
        # Default cell: an LSTM with a non-tuple state.
        self._cell = rnn_cell.LSTMCell(
            num_units=num_units,
            input_size=cell_input_size,
            state_is_tuple=False)
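# NOTE: the two lines that stood here ("comment list" / "article table of
# contents") were web-page navigation residue, not part of the program.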