# TF 1.x contrib imports, assumed at module level:
import tensorflow as tf
from tensorflow.contrib import rnn, seq2seq


def _build_model(self, batch_size, helper_build_fn, decoder_maxiters=None, alignment_history=False):
    # Embed input_data into a one-hot representation.
    inputs = tf.one_hot(self.input_data, self._input_size, dtype=self._dtype)
    inputs_len = self.input_lengths

    with tf.name_scope('bidir-encoder'):
        # Three-layer forward and backward RNN stacks for the bidirectional encoder.
        fw_cell = rnn.MultiRNNCell([rnn.BasicRNNCell(self._enc_rnn_size) for _ in range(3)],
                                   state_is_tuple=True)
        bw_cell = rnn.MultiRNNCell([rnn.BasicRNNCell(self._enc_rnn_size) for _ in range(3)],
                                   state_is_tuple=True)
        fw_cell_zero = fw_cell.zero_state(batch_size, self._dtype)
        bw_cell_zero = bw_cell.zero_state(batch_size, self._dtype)

        # enc_out is a (fw_outputs, bw_outputs) pair, each [batch, time, enc_rnn_size].
        enc_out, _ = tf.nn.bidirectional_dynamic_rnn(fw_cell, bw_cell, inputs,
                                                     sequence_length=inputs_len,
                                                     initial_state_fw=fw_cell_zero,
                                                     initial_state_bw=bw_cell_zero)

    with tf.name_scope('attn-decoder'):
        dec_cell_in = rnn.GRUCell(self._dec_rnn_size)
        # Concatenate forward/backward encoder outputs along the feature axis
        # to form the attention memory: [batch, time, enc_rnn_size * 2].
        attn_values = tf.concat(enc_out, 2)
        attn_mech = seq2seq.BahdanauAttention(self._enc_rnn_size * 2, attn_values,
                                              memory_sequence_length=inputs_len)
        dec_cell_attn = rnn.GRUCell(self._enc_rnn_size * 2)
        dec_cell_attn = seq2seq.AttentionWrapper(dec_cell_attn,
                                                 attn_mech,
                                                 attention_layer_size=self._enc_rnn_size * 2,
                                                 alignment_history=alignment_history)
        dec_cell_out = rnn.GRUCell(self._output_size)
        dec_cell = rnn.MultiRNNCell([dec_cell_in, dec_cell_attn, dec_cell_out],
                                    state_is_tuple=True)

        dec = seq2seq.BasicDecoder(dec_cell, helper_build_fn(),
                                   dec_cell.zero_state(batch_size, self._dtype))

        # dynamic_decode returns (outputs, final_state, final_sequence_lengths)
        # in TF >= 1.2 (the versions that ship seq2seq.AttentionWrapper).
        dec_out, dec_state, _ = seq2seq.dynamic_decode(dec, output_time_major=False,
                                                       maximum_iterations=decoder_maxiters,
                                                       impute_finished=True)

    self.outputs = dec_out.rnn_output
    self.output_ids = dec_out.sample_id
    self.final_state = dec_state
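
helper_build_fn is expected to return a tf.contrib.seq2seq.Helper. As a minimal sketch of how it might be supplied — assuming hypothetical attributes (self.target_inputs, self.target_lengths) and hypothetical start/end token ids that do not appear in the snippet above:

# Minimal sketch of two possible helper builders; target_inputs, target_lengths,
# start_id, and end_id are hypothetical names not defined in the snippet above.

def _training_helper(self):
    # Training: feed the ground-truth target symbols (one-hot, matching the
    # encoder's input representation) to the decoder at every step.
    return seq2seq.TrainingHelper(
        tf.one_hot(self.target_inputs, self._output_size, dtype=self._dtype),
        self.target_lengths)

def _greedy_infer_helper(self, batch_size, start_id, end_id):
    # Inference: greedily feed back the one-hot of the previous argmax id
    # until end_id is emitted (or decoder_maxiters is reached).
    return seq2seq.GreedyEmbeddingHelper(
        lambda ids: tf.one_hot(ids, self._output_size, dtype=self._dtype),
        tf.fill([batch_size], start_id),
        end_id)

# Usage: _build_model takes a zero-argument callable, e.g.
#   self._build_model(batch_size, lambda: self._training_helper())

When alignment_history=True, the per-step attention weights should be readable from the AttentionWrapper's slot in the final MultiRNNCell state (index 1 in this three-cell stack), e.g. self.final_state[1].alignment_history.stack(), which yields a [time, batch, memory_time] tensor useful for plotting alignments.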