def ctc_loss(label, prob, remainder, seq_length, batch_size, num_gpu=1, big_num=1e10):
    """Return the CTC negative log-likelihood of ``label`` for one sequence.

    Runs the CTC forward (alpha) recursion in log space over the label
    sequence interleaved with blanks (class index 0 is used as the blank).

    Args:
        label: Sequence of class indices for the target; each element is
            converted with ``int``.  May be empty, in which case only the
            all-blank path contributes.
        prob: 2-D array indexed as ``prob[row][class]``.  The row used for
            time step ``t`` is ``t * int(batch_size / num_gpu) + remainder``.
            Presumably these are per-frame softmax outputs with the samples
            of a batch interleaved along the first axis -- confirm against
            the caller.  The array is NOT modified.
        remainder: Row offset of this sequence inside each time step's group
            of rows.
        seq_length: Number of time steps to consume; must be >= 1.
        batch_size: Total batch size; together with ``num_gpu`` it gives the
            row stride between consecutive time steps.
        num_gpu: Divisor applied to ``batch_size`` to obtain the row stride.
        big_num: Probabilities are clamped below at ``1 / big_num`` before
            taking the log, and ``-big_num`` stands in for ``log(0)`` in the
            dynamic-programming table.

    Returns:
        The negative log of the summed probability of all alignments that
        collapse to ``label`` (a NumPy float scalar).

    Raises:
        ValueError: If ``seq_length`` is smaller than 1.
    """
    if seq_length < 1:
        raise ValueError("seq_length must be at least 1, got %r" % (seq_length,))

    blank = 0

    # Clamp on a fresh array so the caller's ``prob`` is left untouched;
    # the floor keeps log() finite for zero probabilities.
    log_prob = np.log(np.maximum(prob, 1 / big_num))

    # Extended label, 1-based: index 0 is padding, then blank, l1, blank, ...
    # The padding lets state ``s`` look back at ``s - 1`` and ``s - 2``
    # without bounds checks.
    extended = [blank, blank]
    for symbol in label:
        extended.append(int(symbol))
        extended.append(blank)
    n_states = 2 * len(label) + 1

    # Row stride between consecutive time steps (loop invariant).
    stride = int(batch_size / num_gpu)

    # alpha[t][s]: log-probability of all prefixes ending in state s at time t.
    alpha = np.full((seq_length, n_states + 1), -big_num)

    # A path may start in the leading blank or, if there is one, in the
    # first real symbol.
    first_frame = log_prob[remainder]
    alpha[0][1] = first_frame[blank]
    if n_states > 1:
        alpha[0][2] = first_frame[extended[2]]

    for t in range(1, seq_length):
        frame = log_prob[t * stride + remainder]
        prev = alpha[t - 1]
        cur = alpha[t]

        # The leading blank can only be reached by staying in it.
        cur[1] = prev[1] + frame[blank]

        for s in range(2, n_states + 1):
            # Stay in the same state or advance by one.
            score = np.logaddexp(prev[s], prev[s - 1])
            # Skipping over the blank between two symbols is allowed only
            # when they differ; state 2 has no symbol before it to skip from.
            if s > 2 and extended[s] != blank and extended[s] != extended[s - 2]:
                score = np.logaddexp(score, prev[s - 2])
            cur[s] = score + frame[extended[s]]

    # Valid alignments end in the trailing blank or in the last symbol.
    # For an empty label the second term is the -big_num padding column.
    last = alpha[seq_length - 1]
    return -np.logaddexp(last[n_states], last[n_states - 1])
# `label` has already been processed by remove_blank.
# `pred` is obtained from pred_best.
评论列表
文章目录