def encode(self, features, labels):
    """Left-pad the source ids, embed them, and run the encoder.

    The ids are flipped twice so that padding moves from the end of each
    row to its beginning while the real tokens keep their order. Example
    with lengths [4, 2]:

        [[1, 2, 3, 4, PAD, PAD, PAD], [2, 3, PAD, PAD, PAD, PAD, PAD]]
        -> [[4, 3, 2, 1, PAD, PAD, PAD], [3, 2, PAD, PAD, PAD, PAD, PAD]]
        -> [[PAD, PAD, PAD, 1, 2, 3, 4], [PAD, PAD, PAD, PAD, PAD, 2, 3]]

    Args:
        features: Mapping holding "source_ids" (batch on axis 0, sequence
            on axis 1) and "source_len" (per-row valid lengths). The
            "source_ids" entry is overwritten in place with the
            left-padded ids.
        labels: Not used by this method.

    Returns:
        Whatever the encoder instance returns when called with the
        embedded source and the unchanged "source_len".
    """
    right_padded_ids = features["source_ids"]
    source_len = features["source_len"]

    # Step 1: flip only the first `source_len` entries of every row, so
    # the valid tokens are reversed and the trailing padding stays put.
    prefix_flipped_ids = tf.reverse_sequence(
        right_padded_ids, source_len, batch_dim=0, seq_dim=1)

    # Step 2: flip each whole row along the sequence axis. The tokens end
    # up in their original order again, now preceded by the padding.
    left_padded_ids = tf.reverse(prefix_flipped_ids, [1])
    features["source_ids"] = left_padded_ids

    embedding_table = self.source_embedding_fairseq()
    embedded_source = tf.nn.embedding_lookup(embedding_table, left_padded_ids)

    encoder = self.encoder_class(
        self.params["encoder.params"],
        self.mode,
        self.source_pos_embedding_fairseq(),
    )
    # The lengths are passed through as-is; only the padding side changed.
    return encoder(embedded_source, source_len)
评论列表
文章目录