def encode(self, inputs, input_length, _parses):
    """Encode ``inputs`` with a stacked bidirectional RNN.

    Two separate ``MultiRNNCell`` stacks, each holding ``self._num_layers``
    cells produced by ``self._make_rnn_cell``, are run forward and backward
    through ``tf.nn.bidirectional_dynamic_rnn`` with ``dtype=tf.float32``.

    Args:
        inputs: sequence tensor handed directly to
            ``bidirectional_dynamic_rnn``; ``time_major`` is left at its
            default, so presumably batch-major -- TODO confirm with callers.
        input_length: per-example lengths, forwarded as ``sequence_length``.
        _parses: accepted for interface compatibility; not used here.

    Returns:
        ``(outputs, state)``. ``outputs`` is the forward and backward output
        tensors concatenated along axis 2. ``state`` has the same nested
        structure as the forward final state, with every leaf replaced by the
        matching forward and backward leaves concatenated along axis 1.
    """
    with tf.name_scope('BiLSTMEncoder'):

        def build_stack():
            # A fresh list of cell objects per direction; the two stacks are
            # built independently, forward first, then backward.
            return tf.contrib.rnn.MultiRNNCell(
                [self._make_rnn_cell(layer) for layer in range(self._num_layers)])

        forward_stack = build_stack()
        backward_stack = build_stack()

        outputs, (fw_final, bw_final) = tf.nn.bidirectional_dynamic_rnn(
            forward_stack,
            backward_stack,
            inputs,
            sequence_length=input_length,
            dtype=tf.float32,
        )

        # Merge the two final states leaf by leaf so that the result keeps the
        # forward state's structure -- this lets a unidirectional decoder
        # consume it unchanged (only the leaf widths double).
        merged_leaves = []
        for fw_leaf, bw_leaf in zip(nest.flatten(fw_final), nest.flatten(bw_final)):
            merged_leaves.append(tf.concat((fw_leaf, bw_leaf), axis=1))
        merged_state = nest.pack_sequence_as(fw_final, merged_leaves)

        return tf.concat(outputs, axis=2), merged_state
encoders.py 文件源码
python
阅读 33
收藏 0
点赞 0
评论 0
评论列表
文章目录