train.py source code

python

Project: taas-examples    Author: caicloud
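import tensorflow as tf                # TensorFlow 1.x API (tf.random_normal, tf.contrib)
from tensorflow.contrib import rnn     # assumed import: the snippet calls rnn.BasicLSTMCell

# NUM_FEATURES, SEED and FLAGS (rnn_hidden_nodes, rnn_num_steps) are
# module-level definitions elsewhere in train.py.

def lstm(X):
    """Return one value per sequence for X of shape [batch, rnn_num_steps, NUM_FEATURES]."""
    batch_size = tf.shape(X)[0]

    # Input projection: NUM_FEATURES -> rnn_hidden_nodes.
    w_in = tf.Variable(tf.random_normal([NUM_FEATURES, FLAGS.rnn_hidden_nodes], seed=SEED))
    b_in = tf.Variable(tf.constant(0.1, shape=[FLAGS.rnn_hidden_nodes]))

    # Fold the time axis into the batch axis so one matmul projects every step.
    input_flat = tf.reshape(X, [-1, NUM_FEATURES])
    input_rnn = tf.matmul(input_flat, w_in) + b_in
    # Restore the [batch, rnn_num_steps, rnn_hidden_nodes] layout dynamic_rnn expects.
    input_rnn = tf.reshape(input_rnn, [-1, FLAGS.rnn_num_steps, FLAGS.rnn_hidden_nodes])

    cell = rnn.BasicLSTMCell(FLAGS.rnn_hidden_nodes, state_is_tuple=True)
    init_state = cell.zero_state(batch_size, dtype=tf.float32)
    output_rnn, final_states = tf.nn.dynamic_rnn(cell, input_rnn, initial_state=init_state, dtype=tf.float32)
    # Keep only the last time step's output: [batch, rnn_hidden_nodes].
    output = output_rnn[:, -1, :]

    # Output projection: rnn_hidden_nodes -> 1 value per sequence.
    w_out = tf.Variable(tf.random_normal([FLAGS.rnn_hidden_nodes, 1], seed=SEED))
    b_out = tf.Variable(tf.constant(0.1, shape=[1]))
    pred = tf.matmul(output, w_out) + b_out
    return pred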
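To follow the tensor shapes through lstm() without a TensorFlow 1.x install, its forward pass can be mirrored in NumPy. This is a minimal sketch under stated assumptions, not the project's code: the helper name lstm_forward and the sizes below are hypothetical (num_features, num_steps and hidden stand in for NUM_FEATURES, FLAGS.rnn_num_steps and FLAGS.rnn_hidden_nodes), the weights are random rather than trained, and the cell step assumes the gate layout of TF 1.x BasicLSTMCell (gates split as i, j, f, o, with forget_bias=1.0).

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def lstm_forward(X, w_in, b_in, kernel, bias, w_out, b_out, forget_bias=1.0):
    """Hypothetical NumPy mirror of lstm(); X has shape [batch, num_steps, num_features]."""
    batch, num_steps, num_features = X.shape
    hidden = w_in.shape[1]

    # Input projection for all time steps at once, then back to [batch, steps, hidden].
    x_rnn = (X.reshape(-1, num_features) @ w_in + b_in).reshape(batch, num_steps, hidden)

    # Zero initial state, like cell.zero_state(batch_size, ...).
    c = np.zeros((batch, hidden))
    h = np.zeros((batch, hidden))
    for t in range(num_steps):
        # Assumption: one step follows TF 1.x BasicLSTMCell (concat [x_t, h], split i, j, f, o).
        gates = np.concatenate([x_rnn[:, t, :], h], axis=1) @ kernel + bias
        i, j, f, o = np.split(gates, 4, axis=1)
        c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
        h = np.tanh(c) * sigmoid(o)

    # h now holds the last step's output (output_rnn[:, -1, :]); project to one value.
    return h @ w_out + b_out


rng = np.random.default_rng(0)
batch, num_steps, num_features, hidden = 4, 5, 3, 8  # hypothetical sizes
X = rng.normal(size=(batch, num_steps, num_features))
params = dict(
    w_in=rng.normal(size=(num_features, hidden)),
    b_in=np.full(hidden, 0.1),
    kernel=rng.normal(size=(2 * hidden, 4 * hidden)) * 0.1,  # [input_depth + hidden, 4 * hidden]
    bias=np.zeros(4 * hidden),
    w_out=rng.normal(size=(hidden, 1)),
    b_out=np.full(1, 0.1),
)
pred = lstm_forward(X, **params)
print(pred.shape)  # (4, 1)
```

The sketch keeps only the running h instead of materializing every step's output; the TF version collects all steps in output_rnn and then slices [:, -1, :], which yields the same last-step values.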