def create_mask_for_keys(self, keys, keys_length):
    """Build an additive mask that suppresses padded key positions.

    Args:
        keys: Tensor whose axis 1 is the key/sequence axis; presumably shaped
            (batch_size, keys_l, ...) -- confirm against callers.
        keys_length: Integer tensor of valid key counts per example,
            presumably of shape (batch_size,).

    Returns:
        float32 tensor of shape (batch_size, 1, 1, keys_l) (given a 1-D
        ``keys_length``) that is 0 at valid positions and -2**30 at padded
        positions. Presumably added to attention logits before a softmax.
    """
    # Prefer the static key length, but fall back to the dynamic one when it
    # is unknown at graph-build time: with maxlen=None, tf.sequence_mask would
    # use max(keys_length), which can be shorter than the actual key axis.
    max_keys = keys.get_shape().as_list()[1]
    if max_keys is None:
        max_keys = tf.shape(keys)[1]

    # batch_size x keys_l: 1.0 at padded positions, 0.0 at valid positions.
    mask = 1.0 - tf.sequence_mask(
        lengths=keys_length, maxlen=max_keys, dtype=tf.float32)
    # Large finite negative (presumably chosen over -inf so that a fully
    # padded row still yields finite softmax outputs).
    mask *= -(2 ** 30)
    # batch_size x 1 x 1 x keys_l, so it can broadcast over the two middle axes.
    mask = tf.expand_dims(tf.expand_dims(mask, 1), 1)
    return mask
评论列表
文章目录