def calculate_mean_edit_distance_and_loss(batch_set, dropout):
    r'''
    Beam-search decode one mini-batch and compute its CTC loss and edit distance.

    Pulls the next batch from ``batch_set``, computes logits with ``BiRNN``,
    and builds the CTC loss, beam-search decoding and edit-distance ops for it.

    Args:
        batch_set: object whose ``next_batch()`` returns the triple
            ``(batch_x, batch_seq_len, batch_y)``.
        dropout: dropout setting, forwarded unchanged to ``BiRNN``.

    Returns:
        A 6-tuple ``(total_loss, avg_loss, distance, mean_edit_distance,
        decoded, batch_y)``:

        - total_loss: the loss tensor produced by the selected CTC op
          (``warp_ctc_loss`` if ``FLAGS.use_warpctc`` else ``tf.nn.ctc_loss``),
        - avg_loss: ``tf.reduce_mean`` of ``total_loss``,
        - distance: edit (Levenshtein) distance, as computed by
          ``tf.edit_distance``, between ``decoded[0]`` and ``batch_y``,
        - mean_edit_distance: ``tf.reduce_mean`` of ``distance``,
        - decoded: the decoded result list from ``tf.nn.ctc_beam_search_decoder``
          (only ``decoded[0]`` is used for the distance),
        - batch_y: the batch's original labels (the verified transcriptions).
    '''
    # Obtain the next batch of data.
    # NOTE(review): batch_y is passed as `labels` to the CTC ops and to
    # tf.edit_distance, so it presumably is a SparseTensor — confirm in batch_set.
    batch_x, batch_seq_len, batch_y = batch_set.next_batch()

    # Calculate the logits of the batch using BiRNN. BiRNN receives the
    # sequence lengths converted to int64; the CTC ops below receive
    # batch_seq_len as-is.
    logits = BiRNN(batch_x, tf.to_int64(batch_seq_len), dropout)

    # Compute the CTC loss using either TensorFlow's `ctc_loss` or Baidu's
    # `warp_ctc_loss`, selected by the `use_warpctc` flag.
    if FLAGS.use_warpctc:
        total_loss = tf.contrib.warpctc.warp_ctc_loss(
            labels=batch_y, inputs=logits, sequence_length=batch_seq_len)
    else:
        total_loss = tf.nn.ctc_loss(
            labels=batch_y, inputs=logits, sequence_length=batch_seq_len)

    # Average the loss over the batch.
    avg_loss = tf.reduce_mean(total_loss)

    # Beam search decode the batch. merge_repeated=False presumably keeps
    # genuinely repeated labels (e.g. doubled letters) in the output beams.
    # Both CTC ops and the decoder use their default time-major layout, so
    # BiRNN's logits are presumably time-major — TODO confirm against BiRNN.
    decoded, _ = tf.nn.ctc_beam_search_decoder(
        logits, batch_seq_len, merge_repeated=False)

    # Compute the edit (Levenshtein) distance of the first (best) decoded
    # path. The decoded indices are cast to int32, presumably to match
    # batch_y's dtype as required by tf.edit_distance.
    distance = tf.edit_distance(tf.cast(decoded[0], tf.int32), batch_y)

    # Compute the mean edit distance over the batch.
    mean_edit_distance = tf.reduce_mean(distance)

    # Return the total and average losses, the per-example distance, the
    # mean edit distance, the decoded batch and the original batch_y.
    return total_loss, avg_loss, distance, mean_edit_distance, decoded, batch_y
# Adam Optimization
# =================
# In contrast to 'Deep Speech: Scaling up end-to-end speech recognition'
# (http://arxiv.org/abs/1412.5567),
# in which 'Nesterov's Accelerated Gradient Descent'
# (http://www.cs.toronto.edu/~fritz/absps/momentum.pdf) was used,
# we use the Adam method for optimization (http://arxiv.org/abs/1412.6980)
# because it generally requires less fine-tuning.
评论列表
文章目录