test_ctc_decoders.py source code

python

Project: speechless  Author: JuliusKunze
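import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.sparse_tensor_to_dense)


# Test method of a unittest.TestCase subclass; the enclosing class is not part of this excerpt.
def test(self):
    def decode_greedily(beam_search: bool, merge_repeated: bool):
        # Time-major logits of shape [max_time=5, batch_size=1, num_classes=2].
        # Class 0 is the label "a"; class 1 (the last class) is the CTC blank.
        # The five frames spell: a, a, blank, a, a.
        aa_ctc_blank_aa_logits = tf.constant(np.array([[[1.0, 0.0]], [[1.0, 0.0]], [[0.0, 1.0]],
                                                       [[1.0, 0.0]], [[1.0, 0.0]]], dtype=np.float32))
        sequence_length = tf.constant(np.array([5], dtype=np.int32))

        # Both decoders return a one-element list of SparseTensors plus a score tensor.
        (decoded_list,), log_probabilities = \
            tf.nn.ctc_beam_search_decoder(inputs=aa_ctc_blank_aa_logits,
                                          sequence_length=sequence_length,
                                          merge_repeated=merge_repeated,
                                          beam_width=1) \
                if beam_search else \
                tf.nn.ctc_greedy_decoder(inputs=aa_ctc_blank_aa_logits,
                                         sequence_length=sequence_length,
                                         merge_repeated=merge_repeated)

        # Convert the sparse result to dense and return the labels of the only batch entry.
        return list(tf.Session().run(tf.sparse_tensor_to_dense(decoded_list)[0]))

    # Beam search with merge_repeated=True also merges the two "a" labels separated by the blank.
    self.assertEqual([0], decode_greedily(beam_search=True, merge_repeated=True))
    self.assertEqual([0, 0], decode_greedily(beam_search=True, merge_repeated=False))
    # The greedy decoder only merges adjacent repeats, so the blank keeps the two "a" labels apart.
    self.assertEqual([0, 0], decode_greedily(beam_search=False, merge_repeated=True))
    self.assertEqual([0, 0, 0, 0], decode_greedily(beam_search=False, merge_repeated=False))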
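
The two expected values for the greedy decoder can be reproduced by hand with a minimal pure-Python sketch. This is not TensorFlow's implementation: the function name `greedy_ctc_decode` is hypothetical, it handles a single sequence only, and it assumes the last class is the blank (as the TensorFlow CTC decoders do). It does not model the beam search case.

```python
def greedy_ctc_decode(logits, merge_repeated=True):
    """Greedy CTC decoding sketch for one sequence (hypothetical helper).

    logits: list of frames, each a list of per-class scores.
    Assumption: the last class index is the CTC blank.
    """
    blank = len(logits[0]) - 1
    # Best path: the highest-scoring class in every frame.
    best_path = [max(range(len(frame)), key=frame.__getitem__) for frame in logits]

    decoded = []
    previous = None
    for label in best_path:
        # A repeat is a label equal to the one in the directly preceding frame.
        is_repeat = merge_repeated and label == previous
        if label != blank and not is_repeat:
            decoded.append(label)
        previous = label  # blanks also reset the "previous" label
    return decoded


# Same frames as in the test above: a, a, blank, a, a
frames = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 0.0]]
print(greedy_ctc_decode(frames, merge_repeated=True))   # [0, 0]
print(greedy_ctc_decode(frames, merge_repeated=False))  # [0, 0, 0, 0]
```

Because the blank frame resets `previous`, the "a" after the blank is not treated as a repeat, which is why merging still yields two labels.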