def lstm_cell(i, o, state):
    """Advance the LSTM by one time step using batched gate matmuls.

    Rather than multiplying by four separate weight matrices, the input
    ``i`` and the previous output ``o`` are each replicated ``m_rows``
    times and stacked. Each stack goes through a single ``tf.batch_matmul``
    against the stacked weights (``m_input_w`` for the input, ``m_middle``
    for the previous output). The sum, plus ``m_biases``, is split back into
    per-row slices. The module-level ``m_input_index``, ``m_forget_index``,
    ``m_update_index`` and ``m_output_index`` pick out each gate's slice.

    NOTE(review): ``m_rows`` is presumably the number of gates (4), and the
    weight tensors presumably have a leading dimension of ``m_rows``.
    Confirm against where they are defined.
    NOTE(review): ``tf.pack`` / ``tf.batch_matmul`` / ``tf.unpack`` are
    pre-1.0 TensorFlow APIs. TF 1.0 renamed/replaced them with
    ``tf.stack`` / ``tf.matmul`` / ``tf.unstack``.

    Args:
        i: Input for the current time step.
        o: Output produced at the previous time step.
        state: Cell state carried over from the previous time step.

    Returns:
        A ``(output, state)`` tuple holding the new output and the new cell
        state.
    """
    # Replicate the input and the previous output once per weight row so a
    # single batched matmul serves every gate.
    stacked_in = tf.pack([i] * m_rows)
    stacked_prev = tf.pack([o] * m_rows)

    pre_activations = (
        tf.batch_matmul(stacked_in, m_input_w)
        + tf.batch_matmul(stacked_prev, m_middle)
        + m_biases
    )
    # Split the stacked result back into one tensor per row (gate).
    slices = tf.unpack(pre_activations)

    in_gate = tf.sigmoid(slices[m_input_index])
    keep_gate = tf.sigmoid(slices[m_forget_index])
    candidate = slices[m_update_index]
    # Keep part of the old state and add the gated, squashed candidate.
    new_state = keep_gate * state + in_gate * tf.tanh(candidate)
    out_gate = tf.sigmoid(slices[m_output_index])
    return out_gate * tf.tanh(new_state), new_state
评论列表
文章目录