main.py source code

python

Project: RaSoR-in-Tensorflow    Author: YerevaNN
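import tensorflow as tf  # TensorFlow 1.x; tf.contrib does not exist in 2.x
from tensorflow.contrib import rnn

# n_hidden (units per direction) and dropout (drop rate) are module-level globals not defined in this snippet.
def BiLSTM(input, input_mask, name):
    """Run a BiLSTM over input [batch, time, dim] and return [batch, time, 2 * n_hidden],
    multiplied by input_mask (presumably [batch, time, 1]) so padded steps become zero."""
    with tf.variable_scope(name):
        # Variational dropout on the recurrent state only: one dropout mask per sequence, reused at every step.
        # (Input dropout is disabled; with variational_recurrent=True it would also require input_size.)
        lstm_fw_cell = rnn.LSTMCell(n_hidden, forget_bias=1.0)
        lstm_fw_cell = rnn.DropoutWrapper(lstm_fw_cell, state_keep_prob=1.0 - dropout,
                                          variational_recurrent=True, dtype=tf.float32)
        lstm_bw_cell = rnn.LSTMCell(n_hidden, forget_bias=1.0)
        lstm_bw_cell = rnn.DropoutWrapper(lstm_bw_cell, state_keep_prob=1.0 - dropout,
                                          variational_recurrent=True, dtype=tf.float32)
        # outputs is a (forward, backward) tuple of [batch, time, n_hidden] tensors.
        # No sequence_length is passed, so the backward cell reverses over the full padded length.
        outputs, states = tf.nn.bidirectional_dynamic_rnn(lstm_fw_cell, lstm_bw_cell, input,
                                                          dtype=tf.float32)
    # Join both directions on the feature axis, then zero out padded positions.
    outputs = tf.concat(outputs, axis=-1) * input_mask
    return outputs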
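The last two lines of BiLSTM (concatenating the two directions and multiplying by the mask) can be checked outside TensorFlow. The NumPy sketch below mirrors `tf.concat(outputs, axis=-1) * input_mask`, assuming `input_mask` has shape [batch, time, 1] with 1.0 at real tokens and 0.0 at padding; the helper `concat_and_mask` and the toy shapes are hypothetical and not part of the original project.

```python
import numpy as np

def concat_and_mask(fw_out, bw_out, mask):
    """Hypothetical NumPy mirror of `tf.concat(outputs, axis=-1) * input_mask`.

    fw_out, bw_out: [batch, time, n_hidden] outputs of each direction.
    mask: [batch, time, 1], 1.0 at real tokens and 0.0 at padding (assumed layout).
    """
    both = np.concatenate([fw_out, bw_out], axis=-1)  # [batch, time, 2 * n_hidden]
    return both * mask  # the size-1 last axis of mask broadcasts over all features

batch, time, n_hidden = 2, 4, 3
fw = np.ones((batch, time, n_hidden))       # stand-in forward outputs
bw = 2 * np.ones((batch, time, n_hidden))   # stand-in backward outputs
lengths = np.array([4, 2])                  # second sequence ends with two padded steps
mask = (np.arange(time)[None, :] < lengths[:, None]).astype(np.float32)[..., None]

out = concat_and_mask(fw, bw, mask)
print(out.shape)      # (2, 4, 6)
print(out[1, :, 0])   # [1. 1. 0. 0.]
```

Summing the same mask over its time and feature axes, `mask.sum(axis=(1, 2))`, recovers the per-sequence lengths, which is the form `tf.nn.bidirectional_dynamic_rnn` accepts as its optional `sequence_length` argument.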