conv_decoder_fairseq.py source code

python

Project: conv_seq2seq    Author: tobyyouup
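```python
import tensorflow as tf
# Assumed import path: EncoderOutput is the namedtuple from the google/seq2seq
# codebase that conv_seq2seq builds on.
from seq2seq.encoders.encoder import EncoderOutput


def initialize(self, name=None):
    # `name` is unused; it is kept for compatibility with the decoder interface.

    # No beam has finished yet.
    finished = tf.tile([False], [self.config.beam_width])

    # Every beam starts from the start token: embed it, add a time axis, then
    # zero-pad up to max_decode_length so the convolutional decoder always
    # receives a fixed-length input sequence.
    start_tokens_batch = tf.fill([self.config.beam_width], self.start_tokens)
    first_inputs = tf.nn.embedding_lookup(self.target_embedding, start_tokens_batch)
    first_inputs = tf.expand_dims(first_inputs, 1)  # [beam_width, 1, embed_dim]
    embed_dim = self.target_embedding.get_shape().as_list()[-1]
    zeros_padding = tf.zeros(
        [self.config.beam_width, self.params['max_decode_length'] - 1, embed_dim])
    first_inputs = tf.concat([first_inputs, zeros_padding], axis=1)

    # Replicate the encoder outputs once per beam (assumes a source batch size of 1).
    outputs = tf.tile(self.initial_state.outputs, [self.config.beam_width, 1, 1])
    attention_values = tf.tile(
        self.initial_state.attention_values, [self.config.beam_width, 1, 1])
    enc_output = EncoderOutput(
        outputs=outputs,
        final_state=self.initial_state.final_state,
        attention_values=attention_values,
        attention_values_length=self.initial_state.attention_values_length)

    return finished, first_inputs, enc_output
```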
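The shapes produced by this initialization are easier to see with concrete numbers. Below is a minimal NumPy sketch of the same steps. It assumes a single source sentence (batch size 1) and a plain embedding matrix. `initialize_beams` and all sizes are hypothetical. NumPy's `tile`, `full`, `expand_dims` and `concatenate` stand in for the TensorFlow ops, and only the encoder `outputs` are tiled, since attention values would be handled the same way.

```python
import numpy as np


def initialize_beams(start_token, embedding, enc_outputs, beam_width, max_decode_length):
    """Hypothetical NumPy mirror of the TensorFlow `initialize` method.

    embedding:   [vocab_size, embed_dim] target embedding matrix
    enc_outputs: [1, src_len, hidden] encoder outputs for one source sentence
    """
    # One "not finished" flag per beam (np.tile mirrors tf.tile).
    finished = np.tile([False], [beam_width])

    # Same start token for every beam, embedded and given a time axis.
    start_tokens_batch = np.full([beam_width], start_token)
    first_inputs = embedding[start_tokens_batch]        # [beam, embed_dim]
    first_inputs = np.expand_dims(first_inputs, 1)      # [beam, 1, embed_dim]

    # Zero-pad the remaining max_decode_length - 1 time steps.
    embed_dim = embedding.shape[-1]
    zeros_padding = np.zeros([beam_width, max_decode_length - 1, embed_dim])
    first_inputs = np.concatenate([first_inputs, zeros_padding], axis=1)

    # Copy the encoder outputs once per beam.
    outputs = np.tile(enc_outputs, [beam_width, 1, 1])  # [beam, src_len, hidden]
    return finished, first_inputs, outputs


# Toy sizes, chosen only for illustration.
rng = np.random.default_rng(0)
embedding = rng.normal(size=(10, 4))      # vocab 10, embed_dim 4
enc_outputs = rng.normal(size=(1, 7, 8))  # batch 1, src_len 7, hidden 8
finished, first_inputs, outputs = initialize_beams(
    start_token=2, embedding=embedding, enc_outputs=enc_outputs,
    beam_width=5, max_decode_length=6)

print(finished.shape, first_inputs.shape, outputs.shape)
# (5,) (5, 6, 4) (5, 7, 8)
```

Tiling a `[B, T, H]` tensor by `[k, 1, 1]` repeats the whole batch `k` times (`b0, b1, ..., b0, b1, ...`) rather than placing each sentence's `k` copies next to each other. The per-beam replication is therefore only straightforward when the batch size is 1, which this method presumably assumes.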