multihead_attention.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:attention 作者: louishenrifranc 项目源码 文件源码
def create_mask_for_keys(self, keys, keys_length):
        # batch_size x keys_l
        mask = 1 - tf.sequence_mask(lengths=keys_length, maxlen=keys.get_shape().as_list()[1], dtype=tf.float32)
        mask *= -2 ** 30
        mask = tf.expand_dims(tf.expand_dims(mask, 1), 1)  # batch_size x 1 x 1 x keys_l
        return mask
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号