def lstm_cell1(i, o, state):
    """Build one LSTM cell step using the stacked gate parameters.

    Formulation reference: http://arxiv.org/pdf/1402.1128v1.pdf
    This variant leaves out the connections between the previous state
    and the gates.

    ``i`` and ``o`` are each replicated ``m_rows`` times and stacked, then
    multiplied (batched) by ``m_input_w2`` and ``m_middle_w2`` respectively;
    the two products and ``m_biases`` are summed. The summed tensor is split
    along its first axis and the slices used for the input gate, forget
    gate, update and output gate are selected with ``m_input_index``,
    ``m_forget_index``, ``m_update_index`` and ``m_output_index``.

    Args:
        i: Input tensor for this step.
        o: Saved output tensor (presumably the previous step's output;
            confirm against the caller).
        state: Cell state tensor that the forget gate is applied to.

    Returns:
        A tuple ``(output, new_state)`` where ``new_state`` is
        ``forget_gate * state + input_gate * tanh(update)`` and ``output``
        is ``output_gate * tanh(new_state)``.
    """
    # Replicate both operands so one batched matmul covers every slice.
    stacked_input = tf.pack([i for _ in range(m_rows)])
    stacked_saved_output = tf.pack([o for _ in range(m_rows)])
    # NOTE: dropout on the stacked input is currently disabled:
    # stacked_input = tf.nn.dropout(stacked_input, keep_prob)

    input_term = tf.batch_matmul(stacked_input, m_input_w2)
    recurrent_term = tf.batch_matmul(stacked_saved_output, m_middle_w2)
    pre_activations = tf.unpack(input_term + recurrent_term + m_biases)

    input_gate = tf.sigmoid(pre_activations[m_input_index])
    forget_gate = tf.sigmoid(pre_activations[m_forget_index])
    update_pre_activation = pre_activations[m_update_index]

    # Keep the old state where the forget gate is open and blend in the
    # squashed update where the input gate is open.
    new_state = forget_gate * state + input_gate * tf.tanh(update_pre_activation)

    output_gate = tf.sigmoid(pre_activations[m_output_index])
    cell_output = output_gate * tf.tanh(new_state)
    return cell_output, new_state
# Input data.
# [web-page residue] Comment list
# [web-page residue] Table of contents