def apply(self, is_train, x, x_mask=None):
    x_word_dim = tf.shape(x)[1]

    # Attention scores of each word against every other word: (batch, x_word, x_word)
    dist_matrix = self.attention.get_scores(x, x)

    # Mask out the diagonal so words cannot attend to themselves
    dist_matrix += tf.expand_dims(tf.eye(x_word_dim) * VERY_NEGATIVE_NUMBER, 0)

    # Mask out padded positions in both the query and key dimensions
    joint_mask = compute_attention_mask(x_mask, x_mask, x_word_dim, x_word_dim)
    if joint_mask is not None:
        dist_matrix += VERY_NEGATIVE_NUMBER * (1 - tf.cast(joint_mask, dist_matrix.dtype))

    if not self.alignment_bias:
        select_probs = tf.nn.softmax(dist_matrix)
    else:
        # Allow near-zero total attention by adding a learned bias to the softmax normalizer
        bias = tf.exp(tf.get_variable("no-alignment-bias",
                                      initializer=tf.constant(-1.0, dtype=tf.float32)))
        dist_matrix = tf.exp(dist_matrix)
        select_probs = dist_matrix / (tf.reduce_sum(dist_matrix, axis=2, keep_dims=True) + bias)

    # Attention-weighted sum of the inputs: (batch, x_word, dim)
    response = tf.matmul(select_probs, x)

    if self.merge is not None:
        with tf.variable_scope("merge"):
            response = self.merge.apply(is_train, response, x)
    return response
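The alignment-bias branch is the interesting part: a plain softmax forces each word's attention weights to sum to exactly 1, so every word must attend to something. Adding a learned positive bias to the normalizer lets the row sum fall below 1, with the remaining mass acting as a "no alignment" option. A minimal NumPy sketch of this effect (the scores and variable names here are hypothetical, not from the original code):

import numpy as np

scores = np.array([-2.0, -1.5, -3.0])   # hypothetical alignment scores for one word
exp_scores = np.exp(scores)

plain_softmax = exp_scores / exp_scores.sum()
print(plain_softmax.sum())               # 1.0 -- all mass must be assigned

bias = np.exp(-1.0)                      # matches the -1.0 initializer above
biased_probs = exp_scores / (exp_scores.sum() + bias)
print(biased_probs.sum())                # ~0.53 -- the rest is "attend to nothing"

Note that exponentiating VERY_NEGATIVE_NUMBER underflows to zero, so the self-mask and padding mask still zero out those positions under the biased normalization, just as they do under the ordinary softmax.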