basic_seq2seq.py source code

python

Project: seq2seq  Author: google
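def _decode_infer(self, decoder, bridge, _encoder_output, features, labels):
    """Runs decoding in inference mode."""
    batch_size = self.batch_size(features, labels)
    if self.use_beam_search:
        # With beam search enabled, the configured beam width is used
        # in place of the input batch size.
        batch_size = self.params["inference.beam_search.beam_width"]

    # Every sequence starts from SEQUENCE_START; the helper feeds back the
    # embedding of the highest-scoring token until SEQUENCE_END is emitted.
    target_start_id = self.target_vocab_info.special_vocab.SEQUENCE_START
    helper_infer = tf_decode_helper.GreedyEmbeddingHelper(
        embedding=self.target_embedding,
        start_tokens=tf.fill([batch_size], target_start_id),
        end_token=self.target_vocab_info.special_vocab.SEQUENCE_END)
    # The bridge supplies the decoder's initial state; _encoder_output is
    # not used directly in this method.
    decoder_initial_state = bridge()
    return decoder(decoder_initial_state, helper_infer)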
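The method above hands the decoding loop to a greedy helper. As a rough illustration of what greedy decoding means, here is a minimal pure-Python sketch. It is not the TensorFlow helper's implementation: `greedy_decode`, `toy_step`, and the five-token vocabulary are hypothetical names and values chosen only for this example.

```python
from typing import Callable, List, Tuple


def greedy_decode(step_fn: Callable, batch_size: int, start_id: int,
                  end_id: int, max_steps: int = 20) -> List[List[int]]:
    """Greedy decoding sketch: pick the highest-scoring token at every step.

    step_fn(tokens, state) is a hypothetical stand-in for one decoder step.
    It returns (scores, state), where scores[b] lists one score per
    vocabulary entry for batch item b.
    """
    # Analogous to tf.fill([batch_size], target_start_id) in the method above.
    tokens = [start_id] * batch_size
    finished = [False] * batch_size
    outputs: List[List[int]] = [[] for _ in range(batch_size)]
    state = None
    for _ in range(max_steps):
        scores, state = step_fn(tokens, state)
        next_tokens = []
        for b in range(batch_size):
            if finished[b]:
                # Finished sequences keep receiving the end token.
                next_tokens.append(end_id)
                continue
            # Argmax over the vocabulary scores for this batch item.
            best = max(range(len(scores[b])), key=scores[b].__getitem__)
            outputs[b].append(best)  # the end token itself is recorded too
            finished[b] = best == end_id
            next_tokens.append(best)
        tokens = next_tokens
        if all(finished):
            break
    return outputs


def toy_step(tokens: List[int], state) -> Tuple[List[List[float]], object]:
    """Toy 'model' over a 5-token vocabulary: always predicts token + 1."""
    vocab_size = 5
    scores = [[1.0 if v == min(t + 1, vocab_size - 1) else 0.0
               for v in range(vocab_size)] for t in tokens]
    return scores, state


# Start token 0, end token 4, two sequences decoded in parallel.
decoded = greedy_decode(toy_step, batch_size=2, start_id=0, end_id=4)
```

Each sequence is extended one token at a time and stops once the end token is chosen or `max_steps` is reached, which is the role `start_tokens` and `end_token` play in the helper constructed above.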