def BiRNN(x, n_hidden):
# Prepare data shape to match `bidirectional_rnn` function requirements
# Current data input shape: (batch_size, n_steps, n_input)
# Required shape: 'n_steps' tensors list of shape (batch_size, n_input)
# Permuting batch_size and n_steps
x = tf.transpose(x, [1, 0, 2])
# Reshape to (n_steps*batch_size, n_input)
x = tf.reshape(x, [-1, n_input])
# Split to get a list of 'n_steps' tensors of shape (batch_size, n_input)
x = tf.split(0, n_steps, x)
# Define lstm cells with tensorflow
# Forward direction cell
lstm_fw_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)
# Backward direction cell
lstm_bw_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)
# Get lstm cell output
outputs, _, _ = rnn.bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, x, dtype=tf.float32)
return outputs
arch_gdashboard_med_rnn_silent.py 文件源码
python
阅读 22
收藏 0
点赞 0
评论 0
评论列表
文章目录