CTC.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:Theano-NN_Starter 作者: nightinwhite 项目源码 文件源码
def best_path_decoding(self, probs, probs_mask=None):
    """Greedy (best-path) CTC decoding of a batch of symbolic probabilities.

    At every time step the most probable class is taken. A step whose
    argmax equals the previous step's argmax is replaced by the blank, so
    consecutive repeats are collapsed. Blank outputs (class index ``C``,
    the last class) are then dropped.

    Args:
        probs: symbolic tensor of shape T x B x (C+1) (time, batch,
            classes including the trailing blank).
        probs_mask: optional symbolic T x B mask. Steps where the mask is
            zero emit nothing. When ``None`` (the default), every time step
            is treated as valid.

    Returns:
        (label, label_length):
        - ``label`` is a T x B tensor (default float dtype of
          ``tensor.ones``). Column b holds the decoded class indices of
          sequence b in its first ``label_length[b]`` rows and is padded
          with -1 after that.
        - ``label_length`` is an int32 vector of length B.
    """
    T = probs.shape[0]
    B = probs.shape[1]
    C = probs.shape[2] - 1  # index of the blank class (last class)

    maxprob = probs.argmax(axis=2)

    # Collapse repeats: a step equal to the previous step becomes blank, so
    # only the first element of a run is emitted. The first step has no
    # predecessor, hence the leading "not a repeat" row. Build that row with
    # is_double's own dtype to avoid an accidental float upcast.
    is_double = tensor.eq(maxprob[:-1], maxprob[1:])
    is_double = tensor.concatenate(
        [tensor.zeros((1, B), dtype=is_double.dtype), is_double])
    maxprob = tensor.switch(is_double, C * tensor.ones_like(maxprob), maxprob)

    if probs_mask is None:
        # scan cannot iterate over None; without a mask every step is valid.
        probs_mask = tensor.ones((T, B))

    def recursion(maxp, p_mask, label_length, label):
        # Per-step update. A batch entry emits maxp[b] when its step is
        # unmasked and not blank.
        nonzero = p_mask * tensor.neq(maxp, C)
        nonzero_id = nonzero.nonzero()[0]
        # Write each emitted class into that sequence's next free row.
        new_label = tensor.set_subtensor(
            label[label_length[nonzero_id], nonzero_id], maxp[nonzero_id])
        new_label_length = tensor.switch(
            nonzero, label_length + numpy.int32(1), label_length)
        return new_label_length, new_label

    [label_length, label], _ = scan(
        fn=recursion,
        sequences=[maxprob, probs_mask],
        outputs_info=[tensor.zeros((B,), dtype='int32'), -tensor.ones((T, B))])

    # scan stacks the state after every step; the last one is the full decode.
    return label[-1], label_length[-1]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号