import tensorflow as tf  # TensorFlow 1.x API (tf.to_float was removed in TF 2.x)

def average_attention(hidden_states, encoder_input_length, *args, **kwargs):
    # Attention with fixed weights: a uniform average over the valid
    # (unpadded) time steps of each sequence in the batch.
    # hidden_states: [batch, time, hidden]; encoder_input_length: [batch]
    lengths = tf.to_float(tf.expand_dims(encoder_input_length, axis=1))  # [batch, 1]
    mask = tf.sequence_mask(encoder_input_length, maxlen=tf.shape(hidden_states)[1])
    weights = tf.to_float(mask) / lengths  # 1/length on valid steps, 0 on padding: [batch, time]
    weighted_average = tf.reduce_sum(hidden_states * tf.expand_dims(weights, axis=2), axis=1)  # [batch, hidden]
    return weighted_average, weights
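A minimal usage sketch, assuming a TensorFlow 1.x session; the tensor values, and the names states, lengths, and sess, are illustrative and not part of the original:

# Hypothetical example: batch of 2 sequences, padded to 3 steps, hidden size 2.
states = tf.constant([[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
                      [[2.0, 2.0], [4.0, 4.0], [0.0, 0.0]]])  # [batch=2, time=3, hidden=2]
lengths = tf.constant([3, 2])  # true lengths before padding
avg, w = average_attention(states, lengths)
with tf.Session() as sess:
    print(sess.run(avg))  # row 0: mean of 3 steps -> [3., 4.]; row 1: mean of first 2 -> [3., 3.]
    print(sess.run(w))    # [[1/3, 1/3, 1/3], [0.5, 0.5, 0.]]

Note that the padded step in the second sequence receives weight 0, so it does not pull the average toward zero.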