decoder.py source file

python

Project: tefla  Author: openAGI
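def step(self, time_, inputs, state, name=None):
    """Run one decoding step: RNN cell, attention, sampling, next inputs."""
    # Advance the RNN cell by one time step.
    cell_output, cell_state = self.cell(inputs, state)
    # Apply attention and project to vocabulary logits.
    cell_output_new, logits, attention_scores, attention_context = \
        self.compute_output(cell_output)

    # Optionally reverse each example's attention scores along the time axis
    # (axis 1), up to that example's length. `seq_dim`/`batch_dim` are the
    # older TensorFlow names for `seq_axis`/`batch_axis`.
    if self.reverse_scores_lengths is not None:
        attention_scores = tf.reverse_sequence(
            input=attention_scores,
            seq_lengths=self.reverse_scores_lengths,
            seq_dim=1,
            batch_dim=0)

    # Let the helper choose the output token ids for this step.
    sample_ids = self.helper.sample(
        time=time_, outputs=logits, state=cell_state)

    outputs = AttentionDecoderOutput(
        logits=logits,
        predicted_ids=sample_ids,
        cell_output=cell_output_new,
        attention_scores=attention_scores,
        attention_context=attention_context)

    # The helper also decides whether decoding is finished and what to feed next.
    finished, next_inputs, next_state = self.helper.next_inputs(
        time=time_, outputs=outputs, state=cell_state, sample_ids=sample_ids)

    return (outputs, next_state, next_inputs, finished)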
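
The `tf.reverse_sequence` call in `step` is the least obvious part of the method. The following is a minimal pure-Python sketch of its per-example behaviour and is not tefla's implementation. The helper `reverse_sequence_sketch` and the sample scores are hypothetical, chosen only for illustration. The sketch assumes a 2-D batch with time on axis 1, matching `seq_dim=1, batch_dim=0` above.

```python
def reverse_sequence_sketch(rows, seq_lengths):
    """Reverse the first seq_lengths[i] entries of rows[i]; keep the rest in place."""
    return [row[:n][::-1] + row[n:] for row, n in zip(rows, seq_lengths)]


# Hypothetical attention scores for a batch of 2 examples, padded to 4 source steps.
scores = [
    [0.1, 0.2, 0.7, 0.0],  # true source length 3
    [0.5, 0.5, 0.0, 0.0],  # true source length 2
]
reversed_scores = reverse_sequence_sketch(scores, [3, 2])
print(reversed_scores)
```

Only the first `seq_lengths[i]` positions of each row are flipped, so padding positions stay at the end. One plausible use, which the snippet itself does not confirm, is mapping scores back to the original token order when the encoder consumed a reversed source sequence.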