conv_decoder_fairseq.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:conv_seq2seq 作者: tobyyouup 项目源码 文件源码
def _create_position_embedding(self, lengths, maxlen):

    # Slice to size of current sequence
    pe_slice = self.pos_embed[2:maxlen+2, :]
    # Replicate encodings for each element in the batch
    batch_size = tf.shape(lengths)[0]
    pe_batch = tf.tile([pe_slice], [batch_size, 1, 1])

    # Mask out positions that are padded
    positions_mask = tf.sequence_mask(
        lengths=lengths, maxlen=maxlen, dtype=tf.float32)
    positions_embed = pe_batch * tf.expand_dims(positions_mask, 2)

    return positions_embed
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号