def global_attention(state, hidden_states, encoder, encoder_input_length, scope=None, context=None, **kwargs):
    """
    Compute masked soft attention over `hidden_states` and return their weighted sum.

    Everything runs inside a variable scope named `scope`, or
    'attention_<encoder.name>' when `scope` is falsy.

    :param state: query tensor; `context` is concatenated to it along axis 1
        when `context` is given and `encoder.use_context` is truthy
    :param hidden_states: tensor attended over; its axis 1 is the axis that is
        masked and reduced (presumably (batch, time, dim) -- TODO confirm)
    :param encoder: configuration object; the attributes read here are `name`,
        `use_context`, `attn_filters`, `attn_filter_length`, `attn_size`,
        `attn_keep_prob`, `pervasive_dropout`, `layer_norm`, `mult_attn`
        and `attn_temperature`
    :param encoder_input_length: lengths handed to `tf.sequence_mask` to zero
        out positions past the end of each sequence
    :param scope: optional variable scope overriding the default name
    :param context: optional tensor appended to `state` (see above)
    :param kwargs: forwarded untouched to the energy function
    :return: tuple `(weighted_average, weights)`
    """
    with tf.variable_scope(scope or 'attention_{}'.format(encoder.name)):
        use_context = context is not None and encoder.use_context
        if use_context:
            state = tf.concat([state, context], axis=1)

        # Pick the scoring function: the filter-based variant is used whenever
        # the encoder declares attention filters.
        if not encoder.attn_filters:
            energies = compute_energy(
                hidden_states, state,
                attn_size=encoder.attn_size,
                attn_keep_prob=encoder.attn_keep_prob,
                pervasive_dropout=encoder.pervasive_dropout,
                layer_norm=encoder.layer_norm,
                mult_attn=encoder.mult_attn,
                **kwargs
            )
        else:
            energies = compute_energy_with_filter(
                hidden_states, state,
                attn_size=encoder.attn_size,
                attn_filters=encoder.attn_filters,
                attn_filter_length=encoder.attn_filter_length,
                **kwargs
            )

        # Shift by the per-row maximum so the exponential below cannot overflow.
        # NOTE(review): the maximum is taken before the padding mask is applied,
        # so a dominant energy at a padded position could underflow every valid
        # term and make the normaliser zero -- confirm whether this can occur.
        row_max = tf.reduce_max(energies, axis=1, keep_dims=True)
        energies = energies - row_max

        max_time = tf.shape(hidden_states)[1]
        length_mask = tf.sequence_mask(encoder_input_length, maxlen=max_time, dtype=tf.float32)

        # A falsy temperature (None or 0) falls back to 1.0, i.e. no rescaling.
        temperature = encoder.attn_temperature or 1.0
        scaled_energies = energies / temperature
        unnormalized = tf.exp(scaled_energies) * length_mask

        # Normalise over the last axis; masked positions contribute zero.
        normalizer = tf.reduce_sum(unnormalized, axis=-1, keep_dims=True)
        weights = unnormalized / normalizer

        # Broadcast the weights over a trailing axis and sum out axis 1.
        weighted_states = tf.expand_dims(weights, 2) * hidden_states
        weighted_average = tf.reduce_sum(weighted_states, axis=1)

        return weighted_average, weights
评论列表
文章目录