def encode(self, inputs, sequence_length, **kwargs):
    """Run a bidirectional RNN over ``inputs`` and package the result.

    Before the cells are built, the current variable scope's initializer is
    set to a uniform distribution over ``[-init_scale, init_scale]``. This
    mutates the scope itself, so it applies to variables created afterwards
    in that scope, including the RNN weights created below.

    Args:
      inputs: Sequence tensor passed to ``tf.nn.bidirectional_dynamic_rnn``.
        NOTE(review): the layout (batch- vs time-major) is whatever the caller
        selects via ``kwargs`` (``time_major``); batch-major is the TF default.
      sequence_length: Per-example sequence lengths. They are forwarded to the
        RNN and also returned as ``attention_values_length``.
      **kwargs: Extra keyword arguments forwarded to
        ``tf.nn.bidirectional_dynamic_rnn``. ``dtype`` defaults to
        ``tf.float32``, but a caller-supplied ``dtype`` is honoured.

    Returns:
      An ``EncoderOutput`` with these fields:
        * ``outputs`` and ``attention_values``: the forward and backward
          outputs concatenated along axis 2.
        * ``final_state``: the ``(state_fw, state_bw)`` tuple returned by the
          RNN, left unconcatenated.
        * ``attention_values_length``: ``sequence_length``.
    """
    init_scale = self.params["init_scale"]
    scope = tf.get_variable_scope()
    scope.set_initializer(
        tf.random_uniform_initializer(-init_scale, init_scale))

    # Two separately constructed cells built from the same configuration,
    # one per direction.
    cell_fw = training_utils.get_rnn_cell(**self.params["rnn_cell"])
    cell_bw = training_utils.get_rnn_cell(**self.params["rnn_cell"])

    # Passing dtype=tf.float32 explicitly alongside **kwargs raises TypeError
    # when the caller also supplies `dtype`, so float32 is only a default.
    # **kwargs is a fresh dict on every call, so this does not leak to the
    # caller.
    kwargs.setdefault("dtype", tf.float32)
    outputs, states = tf.nn.bidirectional_dynamic_rnn(
        cell_fw=cell_fw,
        cell_bw=cell_bw,
        inputs=inputs,
        sequence_length=sequence_length,
        **kwargs)

    # Only the outputs are concatenated. Axis 2 is the depth axis for both
    # batch-major and time-major RNN outputs. `states` stays a (fw, bw) tuple.
    outputs_concat = tf.concat(outputs, 2)

    return EncoderOutput(
        outputs=outputs_concat,
        final_state=states,
        attention_values=outputs_concat,
        attention_values_length=sequence_length)
评论列表
文章目录