model.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:tf-seq2seq-attn 作者: johncf 项目源码 文件源码
def _build_model(self, batch_size, helper_build_fn, decoder_maxiters=None, alignment_history=False):
        """Assemble the bidirectional encoder and attention decoder graph.

        The encoder runs two three-layer ``BasicRNNCell`` stacks (forward and
        backward) over the one-hot encoded ``self.input_data``. The decoder is
        a three-layer stack: a GRU, an attention-wrapped GRU attending over
        the concatenated encoder outputs, and a final GRU of width
        ``self._output_size``.

        Args:
            batch_size: batch dimension handed to every ``zero_state`` call.
            helper_build_fn: zero-argument callable; its return value is
                passed to ``seq2seq.BasicDecoder`` as the helper.
            decoder_maxiters: forwarded to ``seq2seq.dynamic_decode`` as
                ``maximum_iterations``.
            alignment_history: forwarded to ``seq2seq.AttentionWrapper``.

        Side effects:
            Sets ``self.outputs`` (decoder ``rnn_output``), ``self.output_ids``
            (decoder ``sample_id``) and ``self.final_state``.
        """
        # one-hot encode the raw input ids before feeding the encoder
        encoder_inputs = tf.one_hot(self.input_data, self._input_size, dtype=self._dtype)
        source_lengths = self.input_lengths

        def make_encoder_stack():
            # three stacked vanilla RNN layers; one such stack per direction
            layers = [rnn.BasicRNNCell(self._enc_rnn_size) for _ in range(3)]
            return rnn.MultiRNNCell(layers, state_is_tuple=True)

        with tf.name_scope('bidir-encoder'):
            forward_stack = make_encoder_stack()
            backward_stack = make_encoder_stack()
            forward_init = forward_stack.zero_state(batch_size, self._dtype)
            backward_init = backward_stack.zero_state(batch_size, self._dtype)

            encoder_outputs, _ = tf.nn.bidirectional_dynamic_rnn(
                forward_stack, backward_stack, encoder_inputs,
                sequence_length=source_lengths,
                initial_state_fw=forward_init,
                initial_state_bw=backward_init)

        with tf.name_scope('attn-decoder'):
            # forward + backward encoder outputs are concatenated on the last
            # axis, so every attention-side width is twice the encoder size
            attn_depth = self._enc_rnn_size * 2

            input_layer = rnn.GRUCell(self._dec_rnn_size)
            memory = tf.concat(encoder_outputs, 2)
            attention = seq2seq.BahdanauAttention(attn_depth, memory, source_lengths)
            attention_layer = seq2seq.AttentionWrapper(
                rnn.GRUCell(attn_depth),
                attention,
                attn_depth,
                alignment_history=alignment_history)
            output_layer = rnn.GRUCell(self._output_size)
            decoder_stack = rnn.MultiRNNCell(
                [input_layer, attention_layer, output_layer],
                state_is_tuple=True)

            # build the helper before the initial state, matching the
            # argument evaluation order of a single inline call
            helper = helper_build_fn()
            initial_state = decoder_stack.zero_state(batch_size, self._dtype)
            decoder = seq2seq.BasicDecoder(decoder_stack, helper, initial_state)

            decoded, last_state = seq2seq.dynamic_decode(
                decoder,
                output_time_major=False,
                maximum_iterations=decoder_maxiters,
                impute_finished=True)

        self.outputs = decoded.rnn_output
        self.output_ids = decoded.sample_id
        self.final_state = last_state
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号