tf-keras-skeleton.py source code

python

Project: LIE    Author: EmbraceLife
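```python
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import ctc_ops as ctc
from tensorflow.python.ops import math_ops

# `ctc_label_dense_to_sparse` is defined elsewhere in this backend file.


def ctc_batch_cost(y_true, y_pred, input_length, label_length):
    """Runs the CTC loss algorithm on each batch element.

    Arguments:
        y_true: tensor `(samples, max_string_length)`
            containing the ground-truth labels.
        y_pred: tensor `(samples, time_steps, num_categories)`
            containing the prediction, i.e. the output of the softmax.
        input_length: tensor `(samples, 1)` containing the sequence length for
            each batch item in `y_pred`.
        label_length: tensor `(samples, 1)` containing the sequence length for
            each batch item in `y_true`.

    Returns:
        Tensor with shape `(samples, 1)` containing the
            CTC loss of each element.
    """
    # (samples, 1) -> (samples,). Squeezing only the last axis keeps the
    # batch axis even when samples == 1 (a bare squeeze would yield a scalar).
    # math_ops.cast replaces the deprecated math_ops.to_int32.
    label_length = math_ops.cast(
        array_ops.squeeze(label_length, axis=-1), dtypes.int32)
    input_length = math_ops.cast(
        array_ops.squeeze(input_length, axis=-1), dtypes.int32)
    # ctc_loss expects the labels as an int32 SparseTensor.
    sparse_labels = math_ops.cast(
        ctc_label_dense_to_sparse(y_true, label_length), dtypes.int32)

    # ctc_loss expects time-major inputs: (time_steps, samples, num_categories).
    # It applies a softmax to `inputs` internally, so log-probabilities are
    # passed; the 1e-8 term guards against log(0).
    y_pred = math_ops.log(array_ops.transpose(y_pred, perm=[1, 0, 2]) + 1e-8)

    # ctc_loss returns shape (samples,); expand to (samples, 1).
    return array_ops.expand_dims(
        ctc.ctc_loss(
            inputs=y_pred, labels=sparse_labels, sequence_length=input_length), 1)
```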
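TensorFlow's `ctc_loss` takes its labels as a `SparseTensor`, which is why `y_true` goes through `ctc_label_dense_to_sparse` first. The sketch below illustrates the idea in plain Python; it is not the Keras implementation. The function name `dense_to_sparse` is hypothetical, and it returns an `(indices, values, dense_shape)` triple that mirrors the fields of a `SparseTensor`. Only the first `label_lengths[i]` entries of row `i` are kept, so whatever value pads the rest of `y_true` is ignored.

```python
def dense_to_sparse(labels, label_lengths):
    """Hypothetical sketch: convert padded dense labels to sparse triples.

    labels: list of rows, all padded to the same max_string_length.
    label_lengths: number of valid entries at the start of each row.
    Returns (indices, values, dense_shape), with indices in row-major order.
    """
    indices = []
    values = []
    for i, (row, length) in enumerate(zip(labels, label_lengths)):
        # Keep only the valid prefix of each row; padding is dropped.
        for j in range(length):
            indices.append((i, j))
            values.append(row[j])
    # The dense shape is the full (samples, max_string_length) shape.
    dense_shape = (len(labels), len(labels[0]) if labels else 0)
    return indices, values, dense_shape


idx, vals, shape = dense_to_sparse([[1, 2, 0, 0], [3, 0, 0, 0]], [2, 1])
print(idx, vals, shape)
```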
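Each per-element value returned by `ctc_batch_cost` is the CTC negative log-likelihood, -log p(label | y_pred). Because `ctc_loss` applies a softmax to its `inputs`, feeding it `log(y_pred + 1e-8)` approximately recovers the softmax probabilities themselves. The sketch below therefore works directly on probabilities for a single batch element and uses the standard CTC forward (alpha) recursion. It is an illustration, not TensorFlow's implementation. Like TF's v1 `ctc_loss`, it assumes the blank is the last class (`num_categories - 1`). The names `ctc_neg_log_likelihood` and `ctc_brute_force` are hypothetical. The sketch works in probability space, so it can underflow on long sequences, which is why real implementations work in log space.

```python
import math
from itertools import product


def ctc_neg_log_likelihood(probs, label, blank=None):
    """Hypothetical sketch: -log p(label | probs) via the CTC forward pass.

    probs: list of time steps, each a list of per-class probabilities.
    label: list of class indices, without blanks.
    blank: blank class index; defaults to the last class (TF v1 convention).
    """
    if blank is None:
        blank = len(probs[0]) - 1
    # Extended label with blanks around every symbol: b, l1, b, l2, ..., b
    ext = [blank]
    for c in label:
        ext += [c, blank]
    S = len(ext)
    # alpha[s]: total probability of all paths so far that end at ext[s].
    alpha = [0.0] * S
    alpha[0] = probs[0][ext[0]]
    if S > 1:
        alpha[1] = probs[0][ext[1]]
    for t in range(1, len(probs)):
        new = [0.0] * S
        for s in range(S):
            a = alpha[s]  # stay on the same symbol
            if s >= 1:
                a += alpha[s - 1]  # advance by one position
            # Skipping a blank is allowed only between two different labels.
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                a += alpha[s - 2]
            new[s] = a * probs[t][ext[s]]
        alpha = new
    # Valid paths end on the last label or on the trailing blank.
    total = alpha[S - 1] + (alpha[S - 2] if S > 1 else 0.0)
    return -math.log(total)  # raises ValueError if the label cannot fit


def ctc_brute_force(probs, label, blank=None):
    """Sum the probabilities of every path that collapses to `label` (tiny inputs only)."""
    num_classes = len(probs[0])
    if blank is None:
        blank = num_classes - 1
    total = 0.0
    for path in product(range(num_classes), repeat=len(probs)):
        # Collapse: merge repeated symbols, then drop blanks.
        collapsed, prev = [], None
        for c in path:
            if c != prev and c != blank:
                collapsed.append(c)
            prev = c
        if collapsed == list(label):
            p = 1.0
            for t, c in enumerate(path):
                p *= probs[t][c]
            total += p
    return total


probs = [[0.2, 0.3, 0.5], [0.4, 0.4, 0.2], [0.1, 0.6, 0.3]]  # 3 steps, blank = 2
print(ctc_neg_log_likelihood(probs, [0, 1]))
```

The brute-force helper is included only as a cross-check. Enumerating every path grows exponentially with the number of time steps, and the forward recursion avoids that cost by summing shared prefixes once.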