def rnn(self, sequence, sequence_length, max_length, dropout, batch_size, training,
        num_hidden=TC_MODEL_HIDDEN, num_layers=TC_MODEL_LAYERS):
    """Run a stacked GRU over ``sequence`` and return each example's last valid output.

    Args:
        sequence: Input tensor fed to ``tf.nn.dynamic_rnn`` in its default
            batch-major layout (presumably ``[batch_size, time, features]`` --
            TODO confirm with callers). Its dtype is used for the RNN state.
        sequence_length: Per-example number of valid time steps. Cast to int32
            here for index arithmetic, matching what ``dynamic_rnn`` does.
        max_length: Padded sequence length. Retained for backward
            compatibility; the last output is now located with ``gather_nd``
            along the RNN output's own time axis, so the input no longer has
            to be padded to exactly this length.
        dropout: Passed as ``output_keep_prob`` to ``DropoutWrapper`` -- i.e.
            the probability of KEEPING an activation, not of dropping it.
        batch_size: Batch size used for the zero initial state and for the
            row indices of the final gather.
        training: Evaluated as a Python truth value; when truthy every GRU
            layer is wrapped in output dropout.
        num_hidden: Number of GRU units per layer.
        num_layers: Number of stacked GRU layers.

    Returns:
        Tensor holding, for each of the ``batch_size`` examples, the top-layer
        output at step ``sequence_length - 1`` (step 0 when the length is 0).
    """
    # Build the stacked recurrent network.
    cells = []
    for _ in range(num_layers):
        cell = tf.nn.rnn_cell.GRUCell(num_hidden)
        if training:
            # `dropout` is a keep probability (see docstring).
            cell = tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=dropout)
        cells.append(cell)
    network = tf.nn.rnn_cell.MultiRNNCell(cells)

    # Use the input's dtype for both the explicit initial state and dynamic_rnn
    # so they cannot disagree (previously a hard-coded tf.float32 was passed
    # alongside a state built from sequence.dtype).
    dtype = sequence.dtype
    sequence_output, _ = tf.nn.dynamic_rnn(
        network,
        sequence,
        dtype=dtype,
        sequence_length=sequence_length,
        initial_state=network.zero_state(batch_size, dtype),
    )

    # Select the output at the last valid step of each example. Clamp at 0 so
    # a zero-length example reads its own step 0 instead of index -1 (which,
    # with the old flattened indexing, landed in the previous example's row).
    lengths = tf.cast(sequence_length, tf.int32)
    last_step = tf.maximum(lengths - 1, 0)
    # (row, step) pairs index the batch and time axes directly, so no
    # assumption is made about the padded time dimension.
    gather_indices = tf.stack([tf.range(batch_size), last_step], axis=1)
    output = tf.gather_nd(sequence_output, gather_indices)
    return output
text_classification_model_simple.py 文件源码
python
阅读 28
收藏 0
点赞 0
评论 0
评论列表
文章目录