def __call__(self, input, state, scope=None):  # TODO test
    """Run one time step of the multiplicative LSTM (mLSTM) cell.

    An intermediate state ``m_t`` is built multiplicatively from the input
    and the previous hidden state. It then takes the place of ``h_{t-1}``
    in an otherwise standard LSTM gate computation.

    Args:
        input: input tensor for this step. Presumably 2-D
            ``(batch, input_size)`` since results are split along axis 1.
            TODO: confirm against the caller.
        state: ``(c_prev, h_prev)`` pair, unpacked as a 2-tuple (e.g. the
            ``LSTMStateTuple`` returned by the previous step).
        scope: optional variable-scope name; defaults to the class name.

    Returns:
        ``(new_h, LSTMStateTuple(new_c, new_h))``: the output of this step
        and the new state.
    """
    with tf.variable_scope(scope or type(self).__name__):
        c_prev, h_prev = state

        # Multiplicative intermediate state m_t, cf. equation (18).
        # NOTE(review): a single fused _linear over [input, h_prev] is
        # split in two. Each factor therefore mixes input and h_prev (plus
        # a bias), which generalises (W_mx x_t) * (W_mh h_{t-1}) rather
        # than reproducing it exactly.
        with tf.variable_scope('mul'):
            concat = _linear([input, h_prev], 2 * self._num_units, True)
            proj_input, rec_input = tf.split(
                value=concat, num_or_size_splits=2, axis=1)
            mul_input = proj_input * rec_input

        # Gate pre-activations W_*x x_t + W_*m m_t + b.
        # Both the input and m_t go through one learned projection, so the
        # input size no longer has to equal 4 * num_units. _linear's own
        # bias is the single bias term.
        with tf.variable_scope('rec_input'):
            lstm_mat = _linear([input, mul_input], 4 * self._num_units, True)
        i, j, f, o = tf.split(value=lstm_mat, num_or_size_splits=4, axis=1)

        # LSTM cell/output update; forget_bias is added to the forget-gate
        # pre-activation before the sigmoid.
        new_c = (c_prev * tf.nn.sigmoid(f + self._forget_bias)
                 + tf.nn.sigmoid(i) * tf.nn.tanh(j))
        new_h = tf.nn.tanh(new_c) * tf.nn.sigmoid(o)
        new_state = LSTMStateTuple(new_c, new_h)
    return new_h, new_state
评论列表
文章目录