def _decode_lambda(self, args):
    """
    Decode a batch of predictions inside the tensorflow graph.

    In case kenlm_directory is specified, a modified version of tensorflow
    (available at https://github.com/timediv/tensorflow-with-kenlm)
    is needed to run that extends ctc_decode to use a kenlm decoder.

    :param args:
        Two-element sequence ``(prediction_batch, prediction_lengths)``.
        NOTE(review): presumably batch-major predictions and per-sample
        lengths with a singleton second axis -- confirm against the caller.
    :return:
        Most probable decoded sequence as a dense tensor. Important: entries
        that are absent from the sparse decoding result (blank labels) are
        returned as `-1`.
    """
    import tensorflow as tf

    predictions, lengths = args

    # Swap the first two axes, then add a small epsilon before taking the
    # logarithm so that zero-valued predictions do not produce -inf.
    swapped_predictions = tf.transpose(predictions, perm=[1, 0, 2])
    log_predictions = tf.log(swapped_predictions + 1e-8)

    # Drop axis 1 of the lengths tensor and cast it to int32.
    squeezed_lengths = tf.squeeze(lengths, axis=[1])
    int_lengths = tf.to_int32(squeezed_lengths)

    # The second element of the returned pair (log probability) is not used.
    decoded, _ = self.ctc_get_decoded_and_log_probability_batch(
        log_predictions, int_lengths)

    # Densify every sparse result, filling the gaps with -1.
    dense_results = []
    for sparse in decoded:
        dense_results.append(
            tf.sparse_to_dense(sparse.indices, sparse.dense_shape,
                               sparse.values, default_value=-1))

    # `single` is defined elsewhere in the project; presumably it returns the
    # only element of the list -- TODO confirm.
    return single(dense_results)
评论列表
文章目录