def __init__(self, num_units, forget_bias=1.0, input_size=None,
             state_is_tuple=False, activation=tanh, hyper_num_units=128,
             hyper_embedding_size=32, is_layer_norm=True):
"""Initialize the basic LSTM cell.
Args:
num_units: int, The number of units in the LSTM cell.
hyper_num_units: int, The number of units in the HyperLSTM cell.
forget_bias: float, The bias added to forget gates (see above).
input_size: Deprecated and unused.
state_is_tuple: If True, accepted and returned states are 2-tuples of
the `c_state` and `m_state`. By default (False), they are concatenated
along the column axis. This default behavior will soon be deprecated.
activation: Activation function of the inner states.
"""
    if not state_is_tuple:
        print("%s: Using a concatenated state is slower and will soon be "
              "deprecated. Use state_is_tuple=True." % self)
    if input_size is not None:
        print("%s: The input_size parameter is deprecated." % self)
    self._num_units = num_units
    self._forget_bias = forget_bias
    self._state_is_tuple = state_is_tuple
    self._activation = activation
    self.hyper_num_units = hyper_num_units
    # Total state size: units of the main LSTM plus units of the hypernetwork.
    self.total_num_units = self._num_units + self.hyper_num_units
    # The hypernetwork is itself a smaller BasicLSTMCell.
    self.hyper_cell = rnn_cell.BasicLSTMCell(hyper_num_units)
    self.hyper_embedding_size = hyper_embedding_size
    self.is_layer_norm = is_layer_norm
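# A minimal usage sketch (not part of the original code): it assumes this
# __init__ belongs to a class named `HyperLSTMCell`, and that `rnn_cell` and
# `tanh` are bound at module level in the TF 1.x style the snippet targets:
#
#   import tensorflow as tf    # TF 1.x
#   rnn_cell = tf.nn.rnn_cell  # so rnn_cell.BasicLSTMCell resolves
#   tanh = tf.tanh             # default activation in the signature above
#
#   cell = HyperLSTMCell(num_units=256, state_is_tuple=True,
#                        hyper_num_units=128, hyper_embedding_size=32,
#                        is_layer_norm=True)
#   assert cell.total_num_units == 256 + 128  # main units + hypernetwork units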