def lstm(X):
    """Build a single-layer LSTM regression head on top of ``X``.

    The input is projected to ``FLAGS.rnn_hidden_nodes`` with a dense layer. It is
    then reshaped into sequences of ``FLAGS.rnn_num_steps`` steps and run through
    a ``BasicLSTMCell``. The output at the last time step is mapped to one value
    per sequence.

    Args:
        X: Input tensor whose last dimension is ``NUM_FEATURES``. Presumably
            shaped (batch, FLAGS.rnn_num_steps, NUM_FEATURES) -- TODO confirm
            against the caller. The total element count must be divisible by
            ``FLAGS.rnn_num_steps * NUM_FEATURES`` for the reshapes to succeed.

    Returns:
        A tensor of shape (batch, 1) with the prediction for each sequence.

    Note:
        New ``tf.Variable`` weights are created on every call. Calling this
        twice builds two independent models, not a shared one.
    """
    # Input projection: NUM_FEATURES -> rnn_hidden_nodes.
    w_in = tf.Variable(tf.random_normal([NUM_FEATURES, FLAGS.rnn_hidden_nodes], seed=SEED))
    b_in = tf.Variable(tf.constant(0.1, shape=[FLAGS.rnn_hidden_nodes]))

    # Flatten to 2-D so one matmul applies the projection to every time step.
    flat_input = tf.reshape(X, [-1, NUM_FEATURES])
    input_rnn = tf.matmul(flat_input, w_in) + b_in
    # Restore the (batch, steps, hidden) layout that dynamic_rnn expects.
    input_rnn = tf.reshape(input_rnn, [-1, FLAGS.rnn_num_steps, FLAGS.rnn_hidden_nodes])

    # Take the batch size from the tensor actually fed to the RNN. Then the
    # initial state always matches it, even if X has no leading batch axis.
    batch_size = tf.shape(input_rnn)[0]

    cell = rnn.BasicLSTMCell(FLAGS.rnn_hidden_nodes, state_is_tuple=True)
    init_state = cell.zero_state(batch_size, dtype=tf.float32)
    output_rnn, final_states = tf.nn.dynamic_rnn(
        cell, input_rnn, initial_state=init_state, dtype=tf.float32
    )

    # Keep only the last time step's output: (batch, hidden).
    last_output = output_rnn[:, -1, :]

    # Output projection: rnn_hidden_nodes -> 1.
    w_out = tf.Variable(tf.random_normal([FLAGS.rnn_hidden_nodes, 1], seed=SEED))
    b_out = tf.Variable(tf.constant(0.1, shape=[1]))
    pred = tf.matmul(last_output, w_out) + b_out
    return pred
评论列表
文章目录