def test(self):
def decode_greedily(beam_search: bool, merge_repeated: bool):
aa_ctc_blank_aa_logits = tf.constant(np.array([[[1.0, 0.0]], [[1.0, 0.0]], [[0.0, 1.0]],
[[1.0, 0.0]], [[1.0, 0.0]]], dtype=np.float32))
sequence_length = tf.constant(np.array([5], dtype=np.int32))
(decoded_list,), log_probabilities = \
tf.nn.ctc_beam_search_decoder(inputs=aa_ctc_blank_aa_logits,
sequence_length=sequence_length,
merge_repeated=merge_repeated,
beam_width=1) \
if beam_search else \
tf.nn.ctc_greedy_decoder(inputs=aa_ctc_blank_aa_logits,
sequence_length=sequence_length,
merge_repeated=merge_repeated)
return list(tf.Session().run(tf.sparse_tensor_to_dense(decoded_list)[0]))
self.assertEqual([0], decode_greedily(beam_search=True, merge_repeated=True))
self.assertEqual([0, 0], decode_greedily(beam_search=True, merge_repeated=False))
self.assertEqual([0, 0], decode_greedily(beam_search=False, merge_repeated=True))
self.assertEqual([0, 0, 0, 0], decode_greedily(beam_search=False, merge_repeated=False))
评论列表
文章目录