models.py source code

python

Project: seq2seq    Author: eske
# TF 1.x API (tf.variable_scope, keep_dims argument).
import tensorflow as tf


def global_attention(state, hidden_states, encoder, encoder_input_length, scope=None, context=None, **kwargs):
    with tf.variable_scope(scope or 'attention_{}'.format(encoder.name)):
        # Optionally condition the attention query on an external context vector.
        if context is not None and encoder.use_context:
            state = tf.concat([state, context], axis=1)

        # Unnormalized attention energies over the encoder time steps;
        # compute_energy_with_filter and compute_energy are helpers defined elsewhere in this file.
        if encoder.attn_filters:
            e = compute_energy_with_filter(hidden_states, state, attn_size=encoder.attn_size,
                                           attn_filters=encoder.attn_filters,
                                           attn_filter_length=encoder.attn_filter_length, **kwargs)
        else:
            e = compute_energy(hidden_states, state, attn_size=encoder.attn_size,
                               attn_keep_prob=encoder.attn_keep_prob, pervasive_dropout=encoder.pervasive_dropout,
                               layer_norm=encoder.layer_norm, mult_attn=encoder.mult_attn, **kwargs)

        # Subtract the per-example maximum for numerical stability before exponentiating.
        e -= tf.reduce_max(e, axis=1, keep_dims=True)
        # Zero out padding positions beyond each sequence's true length.
        mask = tf.sequence_mask(encoder_input_length, maxlen=tf.shape(hidden_states)[1], dtype=tf.float32)

        # Temperature-scaled, masked softmax over the time axis.
        T = encoder.attn_temperature or 1.0
        exp = tf.exp(e / T) * mask
        weights = exp / tf.reduce_sum(exp, axis=-1, keep_dims=True)
        # Context vector: attention-weighted average of the encoder hidden states.
        weighted_average = tf.reduce_sum(tf.expand_dims(weights, 2) * hidden_states, axis=1)

        return weighted_average, weights
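
The core of this function is a masked, temperature-scaled softmax followed by a weighted average of the encoder hidden states. The short NumPy sketch below reproduces that computation on dummy data so the masking and normalization steps are easy to inspect; `masked_softmax`, the toy energies, and the lengths are illustrative names, not part of the original project.

    import numpy as np

    def masked_softmax(energies, lengths, temperature=1.0):
        """Masked, temperature-scaled softmax over the time axis (axis 1)."""
        # Subtract the per-row maximum for numerical stability, as in the TF code above.
        e = energies - energies.max(axis=1, keepdims=True)
        # mask[i, t] = 1.0 while t < lengths[i], 0.0 for padding positions.
        time_steps = energies.shape[1]
        mask = (np.arange(time_steps)[None, :] < lengths[:, None]).astype(np.float32)
        exp = np.exp(e / temperature) * mask
        return exp / exp.sum(axis=1, keepdims=True)

    # Toy example: batch of 2, 4 time steps; the second sequence has only 2 valid steps.
    energies = np.array([[1.0, 2.0, 3.0, 0.5],
                         [0.2, 1.5, 9.0, 9.0]])
    lengths = np.array([4, 2])
    weights = masked_softmax(energies, lengths)
    hidden_states = np.random.rand(2, 4, 8)                      # [batch, time, hidden]
    context = (weights[:, :, None] * hidden_states).sum(axis=1)  # [batch, hidden]
    print(weights)          # padded positions receive exactly zero weight
    print(context.shape)    # (2, 8)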