def _ctc_normal(self, predict, labels):
    """Build the symbolic CTC cost in plain (non-log) probability space.

    The cost is the negative log of the probability mass that the
    forward recurrence leaves on the last two label positions after the
    final step of the scan.

    Args:
        predict: Symbolic matrix indexed as ``predict[:, labels]``.
            Presumably (time steps, classes) per-frame probabilities --
            TODO confirm against the caller.
        labels: Symbolic integer vector of label indices. Presumably
            already interleaved with the blank symbol -- TODO confirm.

    Returns:
        Symbolic scalar ``-log(sum(probabilities[-1, -2:]))``.
    """
    size = labels.shape[0]
    blank = self.tpo["CTC_blank"]

    # Pad with two blanks so the three shifted views below each have
    # exactly `size` entries.
    padded = T.concatenate((labels, [blank, blank]))
    current = padded[:-2]
    following = padded[1:-1]
    two_ahead = padded[2:]

    # Jumping from position i to i + 2 is allowed only when the skipped
    # entry is a blank and the labels on either side differ.
    labels_differ = T.neq(current, two_ahead)
    skips_blank = T.eq(following, blank)
    skip_allowed = labels_differ * skips_blank

    # Transition matrix: stay (main diagonal), advance by one (first
    # super-diagonal), conditionally advance by two (second one, masked
    # row-wise by `skip_allowed`).
    stay = T.eye(size)
    advance = T.eye(size, k=1)
    skip = T.eye(size, k=2) * skip_allowed.dimshuffle((0, 'x'))
    transitions = stay + advance + skip

    # One row per scan step, one column per label position.
    emissions = predict[:, labels]

    def step(curr, accum):
        # Propagate the previous path probabilities through the
        # transition matrix, then weight by the current emissions.
        return curr * T.dot(accum, transitions)

    # Start with all probability mass on the first label position.
    start = T.eye(size)[0]
    path_probs, _ = theano.scan(
        step,
        sequences=[emissions],
        outputs_info=[start]
    )

    # NOTE(review): summing the last *two* positions presumably assumes
    # the label sequence ends with a blank -- confirm with the caller.
    total = T.sum(path_probs[-1, -2:])
    return -T.log(total)
评论列表
文章目录