def lstm_cell(X, output, state):
    """Run one step of an LSTM cell without peephole connections.

    Reference: http://arxiv.org/pdf/1402.1128v1.pdf. Unlike the peephole
    variant, the previous cell state is not wired into the input, forget,
    or output gates here.

    Args:
        X: current input; concatenated with ``output`` along axis 1.
        output: previous output of the cell (presumably the first value
            returned by the prior call -- confirm against the caller).
        state: previous cell state.

    Returns:
        A ``(new_output, new_state)`` pair, where
        ``new_state = forget * state + input * tanh(candidate)`` and
        ``new_output = output_gate * tanh(new_state)``.

    Reads the module-level ``W_lstm``, ``b_lstm`` and ``NUM_NODES``. The
    pre-activations are cut into column bands of width ``NUM_NODES`` (the
    last band takes every remaining column), so ``W_lstm`` presumably has
    ``4 * NUM_NODES`` columns -- TODO confirm.
    """
    n = NUM_NODES
    # NOTE(review): (axis, values) is the pre-1.0 TensorFlow argument order
    # for tf.concat; TF >= 1.0 expects tf.concat(values, axis).
    stacked = tf.concat(1, [X, output])
    logits = tf.matmul(stacked, W_lstm) + b_lstm

    # Bands k = 0, 1, 2 (columns [k*n, (k+1)*n)) feed the input, forget
    # and output gates, in that order.
    in_gate, keep_gate, out_gate = (
        tf.sigmoid(logits[:, k * n:(k + 1) * n]) for k in range(3)
    )
    # Everything from column 3*n onward is the candidate pre-activation.
    candidate_logits = logits[:, 3 * n:]

    new_state = keep_gate * state + in_gate * tf.tanh(candidate_logits)
    return out_gate * tf.tanh(new_state), new_state
# Input data.
评论列表
文章目录