def rnn(self, sequence, sequence_length, max_length, dropout, batch_size, training,
        num_hidden=TC_MODEL_HIDDEN, num_layers=TC_MODEL_LAYERS):
    # Bidirectional recurrent network built from two GRU cells.
    cell_fw = tf.nn.rnn_cell.GRUCell(num_hidden)
    cell_bw = tf.nn.rnn_cell.GRUCell(num_hidden)
    seq_dtype = sequence.dtype  # renamed to avoid shadowing the built-in `type`
    (fw_outputs, bw_outputs), _ = tf.nn.bidirectional_dynamic_rnn(
        cell_fw=cell_fw,
        cell_bw=cell_bw,
        initial_state_fw=cell_fw.zero_state(batch_size, seq_dtype),
        initial_state_bw=cell_bw.zero_state(batch_size, seq_dtype),
        inputs=sequence,
        dtype=tf.float32,
        swap_memory=True,
        sequence_length=sequence_length)
    # Concatenate forward and backward outputs along the feature axis:
    # shape [batch_size, max_length, num_hidden * 2].
    sequence_output = tf.concat((fw_outputs, bw_outputs), 2)
    # Pick the output at the last valid timestep of each sequence.
    sequence_output = tf.reshape(sequence_output, [batch_size * max_length, num_hidden * 2])
    indexes = tf.range(batch_size) * max_length + (sequence_length - 1)
    output = tf.gather(sequence_output, indexes)
    return output
Source code from text_classification_model_simple_bidirectional.py (Python)
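A minimal usage sketch, not taken from the source file: the constants, placeholder shapes, and the dense projection layer below are assumptions, and `model` stands in for an instance of the class that defines `rnn` above. It shows how the last-timestep encoding returned by `rnn` might be wired into a softmax classification loss with the TF 1.x graph API:

```python
import tensorflow as tf

# Assumed shapes and sizes, for illustration only.
BATCH_SIZE = 32
MAX_LENGTH = 100
EMBED_DIM = 300
NUM_CLASSES = 5

# Graph inputs: embedded token sequences, their true lengths, and class labels.
sequence = tf.placeholder(tf.float32, [BATCH_SIZE, MAX_LENGTH, EMBED_DIM])
sequence_length = tf.placeholder(tf.int32, [BATCH_SIZE])
labels = tf.placeholder(tf.int32, [BATCH_SIZE])

# `model` is assumed to be an instance of the class that defines rnn() above.
last_output = model.rnn(sequence, sequence_length, MAX_LENGTH,
                        dropout=0.0, batch_size=BATCH_SIZE, training=False)

# Project the [BATCH_SIZE, num_hidden * 2] encoding to class logits and compute the loss.
logits = tf.layers.dense(last_output, NUM_CLASSES)
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))
```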