loss_function.py 文件源码

python
阅读 41 收藏 0 点赞 0 评论 0

项目:recnet 作者: joergfranke 项目源码 文件源码
def _ctc_normal(self, predict,labels):

        n = labels.shape[0]

        labels2 = T.concatenate((labels, [self.tpo["CTC_blank"], self.tpo["CTC_blank"]]))
        sec_diag = T.neq(labels2[:-2], labels2[2:]) * \
                   T.eq(labels2[1:-1], self.tpo["CTC_blank"])

        recurrence_relation = \
            T.eye(n) + \
            T.eye(n, k=1) + \
            T.eye(n, k=2) * sec_diag.dimshuffle((0, 'x'))

        pred_y = predict[:, labels]

        probabilities, _ = theano.scan(
            lambda curr, accum: curr * T.dot(accum, recurrence_relation),
            sequences=[pred_y],
            outputs_info=[T.eye(n)[0]]
        )

        labels_probab = T.sum(probabilities[-1, -2:])
        return -T.log(labels_probab)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号