beam_aligner.py source code

python

Project: almond-nnparser   Author: Stanford-Mobisocial-IoT-Lab
import tensorflow as tf  # TensorFlow 1.x: tf.contrib was removed in TensorFlow 2.0

# BeamSearchOptimizationDecoderOutput/State and FinalBeamSearchOptimizationDecoderOutput
# are project-specific types, not shown in this snippet.
def finalize(self, outputs: BeamSearchOptimizationDecoderOutput,
             final_state: BeamSearchOptimizationDecoderState, sequence_lengths):
    # All output fields are time-major: [max_time, batch_size, ...].
    # Follow the parent pointers backwards to recover each beam's full token sequence.
    predicted_ids = tf.contrib.seq2seq.gather_tree(
        outputs.predicted_ids, outputs.parent_ids,
        sequence_length=sequence_lengths, name='predicted_ids')
    # Total violation loss: sum the per-step losses over the time axis.
    total_loss = tf.reduce_sum(outputs.loss, axis=0, name='violation_loss')

    # Index of the last decoding step; scores are read off at this step.
    predicted_time = tf.shape(predicted_ids)[0]
    last_step = predicted_time - 1
    with tf.name_scope('gold_score'):
        gold_score = outputs.gold_score[last_step]
    with tf.name_scope('sequence_scores'):
        sequence_scores = outputs.scores[last_step]

    return FinalBeamSearchOptimizationDecoderOutput(beam_search_decoder_output=outputs,
                                                    predicted_ids=predicted_ids,
                                                    scores=sequence_scores,
                                                    gold_score=gold_score,
                                                    gold_beam_id=final_state.gold_beam_id,
                                                    num_available_beams=final_state.num_available_beams,
                                                    total_violation_loss=total_loss), final_state
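To make the three steps in `finalize` concrete, here is a minimal NumPy sketch of what they compute: backtracking parent pointers (the job of `gather_tree`), summing the per-step loss over time, and reading scores at the last time step. The names `gather_tree_np` and `finalize_np` are hypothetical, and the `[batch_size, beam_width]` layout for the sequence lengths is an assumption; this is an illustration of the technique, not the TensorFlow kernel. Positions past a beam's length are simply left as 0 here.

```python
import numpy as np


def gather_tree_np(step_ids, parent_ids, sequence_lengths):
    """Hypothetical NumPy analogue of beam-search backtracking.

    step_ids, parent_ids: int arrays of shape [max_time, batch_size, beam_width].
    sequence_lengths: int array of shape [batch_size, beam_width] (assumed layout).
    """
    max_time, batch_size, beam_width = step_ids.shape
    out = np.zeros_like(step_ids)  # positions past a beam's length stay 0
    for b in range(batch_size):
        for k in range(beam_width):
            beam = k
            # Walk backwards from the last valid step, following parent pointers.
            for t in range(int(sequence_lengths[b, k]) - 1, -1, -1):
                out[t, b, k] = step_ids[t, b, beam]
                beam = parent_ids[t, b, beam]
    return out


def finalize_np(loss, scores):
    """Hypothetical analogue of the loss/score part of `finalize`."""
    total_loss = loss.sum(axis=0)               # sum violation loss over time
    last_scores = scores[scores.shape[0] - 1]   # scores at the last time step
    return total_loss, last_scores


step_ids = np.array([[[1, 2]], [[3, 4]], [[5, 6]]])    # [max_time=3, batch=1, beam=2]
parent_ids = np.array([[[0, 0]], [[1, 0]], [[0, 1]]])
lengths = np.array([[3, 3]])
paths = gather_tree_np(step_ids, parent_ids, lengths)
print(paths[:, 0, :].T.tolist())  # [[2, 3, 5], [1, 4, 6]]
```

Note that the token stored at a given step is not necessarily on a beam's final path: beam 0 ends with token 5, but its parent at step 1 is beam 1, so its full sequence is `[2, 3, 5]` rather than `[1, 3, 5]`. This is why `finalize` backtracks instead of reading `predicted_ids` directly.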