production_ex_train.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:TensorFlow-Machine-Learning-Cookbook 作者: PacktPublishing 项目源码 文件源码
def rnn_model(x_data_ph, max_sequence_length, vocab_size, embedding_size,
              rnn_size, dropout_keep_prob):
    """Build an embedding -> basic RNN -> dense -> softmax graph with 2 outputs.

    Args:
        x_data_ph: Tensor of integer indices fed to the embedding lookup.
            Presumably shaped (batch, time) -- confirm against the caller.
        max_sequence_length: Not used by this function; kept in the signature
            so existing callers keep working.
        vocab_size: Number of rows in the embedding matrix.
        embedding_size: Number of columns in the embedding matrix.
        rnn_size: Number of units in the RNN cell, and the input width of the
            final dense layer.
        dropout_keep_prob: Passed to ``tf.nn.dropout`` as its second
            positional argument (``keep_prob`` in the TF 1.x API used here).

    Returns:
        Tensor holding the softmax of the dense layer applied to the RNN
        output at the last time step (2 columns).

        NOTE(review): despite the ``logits_out`` name this is already
        softmax-normalised. If a caller passes it to a loss that applies
        softmax itself, softmax is applied twice -- verify at the call site.
    """
    # Embedding table initialised uniformly between -1.0 and 1.0.
    embedding_mat = tf.Variable(
        tf.random_uniform([vocab_size, embedding_size], -1.0, 1.0))
    embedding_output = tf.nn.embedding_lookup(embedding_mat, x_data_ph)

    # Run the recurrent cell over the embedded sequence; the final state is
    # not needed because the prediction is taken from the per-step outputs.
    cell = tf.nn.rnn_cell.BasicRNNCell(num_units=rnn_size)
    output, _ = tf.nn.dynamic_rnn(cell, embedding_output, dtype=tf.float32)
    output = tf.nn.dropout(output, dropout_keep_prob)

    # dynamic_rnn returns batch-major output by default, so the last time step
    # is index -1 on axis 1. Slicing avoids transposing the whole tensor and,
    # unlike int(output.get_shape()[0]) - 1, does not require the time
    # dimension to be statically known when the graph is built.
    last = output[:, -1, :]

    # Dense layer projecting the last RNN output onto the 2 output columns.
    weight = tf.Variable(tf.truncated_normal([rnn_size, 2], stddev=0.1))
    bias = tf.Variable(tf.constant(0.1, shape=[2]))
    logits_out = tf.nn.softmax(tf.matmul(last, weight) + bias)

    return logits_out


# Define accuracy function
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号