conv_decoder_fairseq_bs.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:conv_seq2seq 作者: tobyyouup 项目源码 文件源码
def next_inputs(self, sample_ids,name=None):
    finished = math_ops.equal(sample_ids, self.config.eos_token)
    all_finished = math_ops.reduce_all(finished)
    next_inputs = control_flow_ops.cond(
        all_finished,
        # If we're finished, the next_inputs value doesn't matter
        lambda:  tf.nn.embedding_lookup(self.target_embedding, tf.tile([self.config.eos_token], [self.config.beam_width])),
        lambda: tf.nn.embedding_lookup(self.target_embedding, sample_ids))
    return all_finished, next_inputs
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号