def ctc_batch_cost(y_true, y_pred, input_length, label_length):
    """Runs CTC loss algorithm on each batch element.

    Arguments:
        y_true: tensor `(samples, max_string_length)`
            containing the truth labels.
        y_pred: tensor `(samples, time_steps, num_categories)`
            containing the prediction, or output of the softmax.
        input_length: tensor `(samples, 1)` containing the sequence length for
            each batch item in `y_pred`.
        label_length: tensor `(samples, 1)` containing the sequence length for
            each batch item in `y_true`.

    Returns:
        Tensor with shape `(samples, 1)` containing the
        CTC loss of each element.
    """
    # Flatten the length tensors to rank 1 with `reshape(..., [-1])` instead of
    # a bare `squeeze`. An unqualified squeeze drops *every* size-1 axis, so
    # with a batch of one sample a `(1, 1)` tensor became a scalar. Neither the
    # sparse-label conversion nor `ctc_loss(sequence_length=...)` (which wants
    # a 1-D vector of length batch_size) can use a scalar. Reshape yields
    # `(samples,)` for both `(samples, 1)` and `(samples,)` inputs.
    label_length = math_ops.to_int32(array_ops.reshape(label_length, [-1]))
    input_length = math_ops.to_int32(array_ops.reshape(input_length, [-1]))
    sparse_labels = math_ops.to_int32(
        ctc_label_dense_to_sparse(y_true, label_length))
    # `ctc_loss` is called with its default time-major layout, so move the
    # time axis to the front: (samples, time_steps, C) -> (time_steps, samples, C).
    # The small constant keeps `log` finite where a probability is exactly 0.
    y_pred = math_ops.log(array_ops.transpose(y_pred, perm=[1, 0, 2]) + 1e-8)
    # `ctc_loss` yields one loss value per batch element; add a trailing axis
    # so the result has shape (samples, 1).
    return array_ops.expand_dims(
        ctc.ctc_loss(
            inputs=y_pred, labels=sparse_labels, sequence_length=input_length),
        1)
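# 评论列表 (comment list) -- web-page residue, not part of the code
# 文章目录 (table of contents) -- web-page residue, not part of the code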