stt_metric.py source file

python

Project: deepspeech.mxnet  Author: samsungsds-rnd
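import numpy as np


def ctc_loss(label, prob, remainder, seq_length, batch_size, num_gpu=1, big_num=1e10):
    # CTC negative log-likelihood of one utterance, computed with the forward
    # (alpha) recursion in log space. Class 0 is the blank symbol.
    # Assumes label is non-empty (label_[2] is read below).
    # Extended label sequence, 1-indexed: blank, l1, blank, l2, ..., blank.
    # Index 0 is unused padding.
    label_ = [0, 0]
    # Clamp tiny probabilities (in place) so that the log stays finite.
    prob[prob < 1 / big_num] = 1 / big_num
    log_prob = np.log(prob)

    l = len(label)
    for i in range(l):
        label_.append(int(label[i]))
        label_.append(0)

    l_ = 2 * l + 1
    # a[t][j]: log-probability of all partial alignments that end at extended
    # position j at time t; -big_num stands in for log(0).
    a = np.full((seq_length, l_ + 1), -big_num)
    # At t = 0 an alignment can only start at the leading blank or the first label.
    a[0][1] = log_prob[remainder][0]
    a[0][2] = log_prob[remainder][label_[2]]
    for i in range(1, seq_length):
        # prob is indexed time-major: time step i of this utterance is at
        # i * (per-GPU batch size) + remainder.
        row = i * int(batch_size / num_gpu) + remainder
        a[i][1] = a[i - 1][1] + log_prob[row][0]
        a[i][2] = np.logaddexp(a[i - 1][2], a[i - 1][1]) + log_prob[row][label_[2]]
        for j in range(3, l_ + 1):
            # Stay at position j or advance from j - 1.
            a[i][j] = np.logaddexp(a[i - 1][j], a[i - 1][j - 1])
            # Skipping the blank at j - 1 is allowed only between two different labels.
            if label_[j] != 0 and label_[j] != label_[j - 2]:
                a[i][j] = np.logaddexp(a[i][j], a[i - 1][j - 2])
            a[i][j] += log_prob[row][label_[j]]

    # A valid alignment ends on the last label or on the trailing blank.
    return -np.logaddexp(a[seq_length - 1][l_], a[seq_length - 1][l_ - 1])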
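
The recursion above can be sanity-checked on tiny inputs against a brute-force sum over all alignments. The following is a minimal sketch, not part of the original file: the helper names `collapse` and `brute_force_ctc_loss` are hypothetical, and it assumes a single utterance whose `prob` has shape (T, V) with class 0 as the blank.

```python
import itertools

import numpy as np


def collapse(path, blank=0):
    """Apply the CTC collapsing rule: merge repeated symbols, then drop blanks."""
    out = []
    prev = None
    for s in path:
        if s != prev and s != blank:
            out.append(s)
        prev = s
    return out


def brute_force_ctc_loss(label, prob):
    """Negative log-likelihood obtained by enumerating every alignment.

    prob has shape (T, V): one row per time step of a single utterance.
    The cost is exponential in T, so this is only usable for very small inputs.
    """
    T, V = prob.shape
    target = [int(s) for s in label]
    total = 0.0
    for path in itertools.product(range(V), repeat=T):
        if collapse(path) == target:
            total += float(np.prod([prob[t, s] for t, s in enumerate(path)]))
    return -np.log(total)


# Two time steps, two classes (blank and symbol 1), uniform probabilities.
# The alignments (1, 0), (0, 1) and (1, 1) all collapse to [1].
prob = np.full((2, 2), 0.5)
loss = brute_force_ctc_loss([1], prob)
print(loss)
```

With `batch_size=1`, `num_gpu=1` and `remainder=0`, calling `ctc_loss([1], prob, 0, 2, 1)` on the same input should give the same value, -log(0.75), since three of the four equally likely alignments collapse to the label.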


# label has already been processed by remove_blank
# pred is obtained from pred_best