seq2seq.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:website-fingerprinting 作者: AxelGoetz 项目源码 文件源码
def _init_decoder(self):
        """
        Build the decoder and store its projected outputs.

        The decoder's input at every step is its own output from the previous
        step, which `dynamic_rnn` cannot express. `tf.nn.raw_rnn` is driven
        with a custom loop function instead.
        (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/rnn.py)

        Sets `self.decoder_inputs_length` and `self.decoder_outputs`.
        """
        # One extra decoding step accounts for the EOS token
        self.decoder_inputs_length = self.encoder_inputs_length + 1

        def step_shape():
            # Shape shared by every tensor fed into / emitted by the cell
            return [self.batch_size, self.decoder_cell.output_size]

        def first_step(time):
            # Called once before the loop starts (raw_rnn passes no cell output)
            done = (time >= self.decoder_inputs_length)

            # First input is an all-EOS tensor (0 + self.EOS)
            eos_input = tf.zeros(step_shape(), dtype=tf.float32) + self.EOS

            # Decoder starts from the encoder's final state; there is nothing
            # to emit yet and no extra loop state is carried around.
            return (done, eos_input, self.encoder_final_state, None, None)

        def loop_fn(time, cell_output, cell_state, loop_state):
            if cell_output is None:  # time == 0
                return first_step(time)

            cell_output.set_shape(step_shape())

            done = (time >= self.decoder_inputs_length)
            all_done = tf.reduce_all(done)

            def pad_input():
                # Zeros act as padding (self.PAD) once every sequence is done
                return tf.zeros(step_shape(), dtype=tf.float32)

            def feed_back():
                # Otherwise the previous output becomes the next input
                return cell_output

            fed_input = tf.cond(all_done, pad_input, feed_back)

            # The cell state passes through untouched, the cell output is
            # emitted as-is, and no loop state is used.
            return (done, fed_input, cell_state, cell_output, None)

        outputs_ta, _, _ = tf.nn.raw_rnn(self.decoder_cell, loop_fn)
        stacked_outputs = outputs_ta.stack()
        # Swap the first two axes of the stacked outputs
        self.decoder_outputs = tf.transpose(stacked_outputs, [1, 0, 2])

        with tf.variable_scope('DecoderOutputProjection') as scope:
            self.decoder_outputs = self.projection(self.decoder_outputs, self.seq_width, scope)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号