def __call__(self, inputs, state, scope=None):
    """Long short-term memory cell (LSTM) with hypernetworks and layer normalization.

    Runs one step of the main LSTM cell. Before the main cell's gates are
    computed, an auxiliary ``self.hyper_cell`` is stepped, and its output is
    made available through ``self.hyper_output``.

    The packed ``state`` tensor is 2-D. Along axis 1 it has the layout
    ``[h, hyper_h, c, hyper_c]``:

    * its first half is the main cell's hidden state ``h`` followed by the
      hyper cell's hidden part;
    * its second half is the main cell state ``c`` followed by the hyper
      cell's cell part.

    ``h`` and ``c`` are each ``self._num_units`` columns wide.

    Args:
      inputs: 2-D input tensor for this step. Presumably
        ``(batch, input_size)`` -- TODO confirm.
      state: packed 2-D state tensor with the layout described above.
      scope: optional variable-scope name. Defaults to the class name.

    Returns:
      A pair ``(new_h, new_total_state)``. ``new_h`` is the main cell's
      output. ``new_total_state`` is the updated state packed in the same
      ``[h, hyper_h, c, hyper_c]`` layout.

    Side effects:
      Sets ``self.hyper_output`` and ``self.hyper_state`` to the hyper cell's
      output and new state for this step. Both are assigned before
      ``self.hyper_norm`` is called. NOTE(review): presumably ``hyper_norm``
      reads ``self.hyper_output``; confirm.

    Note:
      This uses the pre-1.0 TensorFlow argument order
      ``tf.split(axis, num_split, value)`` and ``tf.concat(axis, values)``.
    """
    with vs.variable_scope(scope or type(self).__name__):
        num_units = self._num_units

        # Unpack the state. Each half is [main-cell part | hyper-cell part].
        total_h, total_c = tf.split(1, 2, state)
        h = total_h[:, 0:num_units]
        c = total_c[:, 0:num_units]
        # The hyper cell receives its own packed state: [hyper_h, hyper_c].
        hyper_state = tf.concat(1, [total_h[:, num_units:], total_c[:, num_units:]])

        # Step the hyper cell on the current input and the main cell's previous h.
        hyper_input = tf.concat(1, [inputs, h])
        hyper_output, hyper_new_state = self.hyper_cell(hyper_input, hyper_state)
        # Published on self before hyper_norm is called below.
        self.hyper_output = hyper_output
        self.hyper_state = hyper_new_state

        # Parameters of the four gates are concatenated into one multiply for
        # efficiency. The third argument of _linear is its bias flag, so
        # neither projection adds a bias term here.
        input_below_ = rnn_cell._linear([inputs], 4 * num_units, False, scope="out_1")
        input_below_ = self.hyper_norm(input_below_, 4 * num_units, scope="hyper_x")
        state_below_ = rnn_cell._linear([h], 4 * num_units, False, scope="out_2")
        state_below_ = self.hyper_norm(state_below_, 4 * num_units, scope="hyper_h")

        if self.is_layer_norm:
            s1 = vs.get_variable("s1", initializer=tf.ones([4 * num_units]), dtype=tf.float32)
            s2 = vs.get_variable("s2", initializer=tf.ones([4 * num_units]), dtype=tf.float32)
            # s3/b3 were meant to scale/shift new_c but are not applied (see the
            # NOTE below). They are still created, in the original order, so the
            # model's variable set and existing checkpoints stay unchanged.
            vs.get_variable("s3", initializer=tf.ones([num_units]), dtype=tf.float32)
            b1 = vs.get_variable("b1", initializer=tf.zeros([4 * num_units]), dtype=tf.float32)
            b2 = vs.get_variable("b2", initializer=tf.zeros([4 * num_units]), dtype=tf.float32)
            vs.get_variable("b3", initializer=tf.zeros([num_units]), dtype=tf.float32)
            input_below_ = ln(input_below_, s1, b1)
            state_below_ = ln(state_below_, s2, b2)

        lstm_matrix = tf.add(input_below_, state_below_)
        # Gates along axis 1, named by how they are used below:
        # i = input gate, j = candidate, f = forget gate, o = output gate.
        i, j, f, o = tf.split(1, 4, lstm_matrix)
        new_c = c * sigmoid(f) + sigmoid(i) * self._activation(j)
        # NOTE: new_c is deliberately not layer-normalized. Normalizing c with
        # s3/b3 caused many NaNs in the model, so it is disabled for now.
        new_h = self._activation(new_c) * sigmoid(o)

        # Repack in the same [h, hyper_h, c, hyper_c] layout as the input state.
        hyper_h, hyper_c = tf.split(1, 2, hyper_new_state)
        new_total_h = tf.concat(1, [new_h, hyper_h])
        new_total_c = tf.concat(1, [new_c, hyper_c])
        new_total_state = tf.concat(1, [new_total_h, new_total_c])
        return new_h, new_total_state
评论列表
文章目录