conv_seq2seq.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:conv_seq2seq 作者: tobyyouup 项目源码 文件源码
def encode(self, features, labels):

    features["source_ids"] = tf.reverse_sequence(features["source_ids"], features["source_len"], batch_dim=0, seq_dim=1)  # [[1,2,3,4,PAD,PAD,PAD],[2,3,PAD,PAD,PAD,PAD,PAD]]   [4,2]
    features["source_ids"] = tf.reverse(features["source_ids"],[1])  # --> [[4,3,2,1,PAD,PAD,PAD],[3,2,PAD,PAD,PAD,PAD,PAD]] --> [[PAD,PAD,PAD,1,2,3,4],[PAD,PAD,PAD,PAD,PAD,2,3]]

    source_embedded = tf.nn.embedding_lookup(self.source_embedding_fairseq(),
                                             features["source_ids"])
    encoder_fn = self.encoder_class(self.params["encoder.params"], self.mode, self.source_pos_embedding_fairseq())
    return encoder_fn(source_embedded, features["source_len"])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号