attention_decoder.py source code (Python)


Project: master-thesis · Author: AndreasMadsen. The function below builds a GRU decoder stack with Luong attention over the encoder outputs, using TensorFlow's contrib seq2seq API (TF 1.1 era), and returns the decoder logits together with the sampled label ids, padded to max_length when one is given.
# Assumed imports, not shown in the original snippet; the code targets the
# TensorFlow 1.1-era contrib APIs (DynamicAttentionWrapper was later renamed
# to AttentionWrapper).
import tensorflow as tf
from tensorflow.contrib import rnn, seq2seq


def attention_decoder(enc, length, state_transfer_helper,
                      voca_size=20, max_length=None,
                      name=None, reuse=None):
    # `state_transfer_helper` is a zero-argument callable that returns the
    # tf.contrib.seq2seq Helper used to feed the decoder between time steps.
    with tf.variable_scope(name, "attention-decoder", values=[enc, length],
                           reuse=reuse) as scope:
        # get shapes
        batch_size = enc.get_shape().as_list()[0]
        if batch_size is None:
            batch_size = tf.shape(enc)[0]

        dims = int(enc.get_shape()[-1])

        # decoder: a GRU cell wrapped with Luong attention over the encoder
        dec_attn = seq2seq.DynamicAttentionWrapper(
            cell=rnn.GRUCell(dims, reuse=scope.reuse),
            attention_mechanism=seq2seq.LuongAttention(dims, enc, length),
            attention_size=dims
        )

        # full decoder cell: plain GRU -> attention cell -> GRU sized to the vocabulary
        dec_network = rnn.MultiRNNCell([
            rnn.GRUCell(dims, reuse=scope.reuse),
            dec_attn,
            rnn.GRUCell(voca_size, reuse=scope.reuse)
        ], state_is_tuple=True)

        # the helper controls how outputs are fed back as the next inputs
        decoder = seq2seq.BasicDecoder(
            dec_network, state_transfer_helper(),
            initial_state=dec_network.zero_state(batch_size, tf.float32)
        )

        dec_outputs, _ = seq2seq.dynamic_decode(
            decoder,
            maximum_iterations=max_length,
            impute_finished=False
        )

        logits = dec_outputs.rnn_output   # per-step vocabulary scores
        labels = dec_outputs.sample_id    # sampled (argmax) token ids

        # pad logits and labels
        if max_length is not None:
            logits = dynamic_time_pad(logits, max_length)
            labels = dynamic_time_pad(labels, max_length)

        return logits, labels
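
dynamic_time_pad is a project-local helper that this snippet does not include. Below is a minimal sketch of what it plausibly does, assuming it zero-pads the time axis (axis 1) of a [batch, time, ...] tensor up to `length` steps; the project's actual implementation may differ.

import tensorflow as tf

def dynamic_time_pad(tensor, length):
    # Hypothetical reimplementation: zero-pad axis 1 up to `length` steps,
    # leaving the batch and any trailing feature dimensions untouched.
    time = tf.shape(tensor)[1]
    rank = tensor.get_shape().ndims
    paddings = [[0, 0], [0, length - time]] + [[0, 0]] * (rank - 2)
    return tf.pad(tensor, paddings)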
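A hypothetical usage sketch, not taken from the original project: the placeholder shapes and the choice of seq2seq.TrainingHelper as the state_transfer_helper are illustration-only assumptions.

import tensorflow as tf
from tensorflow.contrib import seq2seq

batch_size, enc_time, dims, voca = 32, 50, 128, 20
enc = tf.placeholder(tf.float32, [batch_size, enc_time, dims])        # encoder outputs
length = tf.placeholder(tf.int32, [batch_size])                       # per-example lengths
dec_inputs = tf.placeholder(tf.float32, [batch_size, enc_time, voca])

# state_transfer_helper must be a zero-argument callable returning a Helper
helper = lambda: seq2seq.TrainingHelper(dec_inputs, length)
logits, labels = attention_decoder(enc, length, helper,
                                   voca_size=voca, max_length=enc_time)

At training time, logits feed a cross-entropy loss against the target ids, while labels are the decoder's argmax predictions at each step.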