import tensorflow as tf


def compute_attention_mask(x_mask, mem_mask, x_word_dim, key_word_dim):
    """Computes a (batch, x_word_dim, key_word_dim) bool mask for clients that want masking."""
    if x_mask is None and mem_mask is None:
        return None
    elif x_mask is None or mem_mask is None:
        raise NotImplementedError()

    # Turn per-example lengths into boolean masks over the word dimensions
    x_mask = tf.sequence_mask(x_mask, x_word_dim)
    mem_mask = tf.sequence_mask(mem_mask, key_word_dim)

    # Broadcast to (batch, x_word_dim, key_word_dim): entry (i, j) is True only
    # if both the i-th query word and the j-th key/memory word are valid
    join_mask = tf.logical_and(tf.expand_dims(x_mask, 2), tf.expand_dims(mem_mask, 1))
    return join_mask
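
Below is a minimal usage sketch, not taken from the original source: it assumes `x_mask` and `mem_mask` are per-example sequence lengths (int tensors of shape (batch,)), assumes TensorFlow 2.x eager mode, and uses hypothetical lengths and dimensions just to show the mask being applied to attention logits before the softmax.

# Hypothetical dimensions and lengths for illustration only
batch, x_word_dim, key_word_dim = 2, 5, 7
x_len = tf.constant([3, 5])      # assumed query-side sequence lengths
mem_len = tf.constant([7, 4])    # assumed key/memory-side sequence lengths

mask = compute_attention_mask(x_len, mem_len, x_word_dim, key_word_dim)

# Mask out invalid positions with a large negative value before the softmax
logits = tf.random.normal([batch, x_word_dim, key_word_dim])
masked_logits = tf.where(mask, logits, tf.fill(tf.shape(logits), -1e30))
attention_weights = tf.nn.softmax(masked_logits, axis=-1)

Using a large negative constant rather than -inf keeps the softmax numerically stable when an entire row happens to be masked.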