def _bilstm_cell():
    """Build one LSTM cell wrapped in variational recurrent-state dropout.

    Reads the module-level ``n_hidden`` (number of LSTM units) and ``dropout``
    (a drop probability; the keep probability used is ``1.0 - dropout``).
    """
    cell = rnn.LSTMCell(n_hidden, forget_bias=1.0)
    # Only the recurrent state is dropped. Adding input dropout together with
    # variational_recurrent=True would also require `input_size`, the static
    # per-step feature depth (a TensorShape), not a runtime tf.shape(...) tensor.
    return tf.contrib.rnn.DropoutWrapper(
        cell,
        state_keep_prob=1.0 - dropout,
        variational_recurrent=True,
        dtype=tf.float32,  # variational_recurrent needs a dtype to build its noise masks
    )


def BiLSTM(input, input_mask, name, sequence_length=None):
    """Run a bidirectional LSTM over `input` and mask the concatenated outputs.

    Args:
        input: sequence tensor fed to ``tf.nn.bidirectional_dynamic_rnn`` with
            its default ``time_major=False``, i.e. batch-major
            ``[batch, time, depth]``. The name shadows the builtin but is kept
            for caller compatibility.
        input_mask: multiplied element-wise with the outputs, so it must
            broadcast against ``[batch, time, 2 * n_hidden]``. It is presumably
            ``[batch, time, 1]``, with 1 for real steps and 0 for padding
            (confirm against callers). It is cast to the output dtype so that
            int/bool masks also work.
        name: variable scope under which both cells' weights are created.
        sequence_length: optional int vector ``[batch]`` of true sequence
            lengths. When given, the backward cell starts at each sequence's
            last real step instead of running through the padding first. When
            omitted (the default), behaviour matches the previous version.

    Returns:
        Tensor ``[batch, time, 2 * n_hidden]``: the forward and backward
        outputs concatenated on the last axis and multiplied by ``input_mask``.
    """
    with tf.variable_scope(name):
        fw_cell = _bilstm_cell()
        bw_cell = _bilstm_cell()
        # Without sequence_length the backward pass reverses the whole padded
        # sequence, so padding leaks into the backward states of real steps;
        # masking the outputs afterwards cannot undo that.
        (fw_out, bw_out), _ = tf.nn.bidirectional_dynamic_rnn(
            fw_cell,
            bw_cell,
            input,
            sequence_length=sequence_length,
            dtype=tf.float32,
        )
        outputs = tf.concat([fw_out, bw_out], axis=-1)
        return outputs * tf.cast(input_mask, outputs.dtype)
评论列表
文章目录