ctc_cost.py source code

python

Project: KGP-ASR   Author: KGPML

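import theano
import theano.tensor as T


def _labeling_batch_to_class_batch(y, y_labeling, num_classes,
                                   y_hat_mask=None):
    # Map weights over labeling positions to weights over classes by
    # summing every position that carries the same class label.
    #   y:          integer matrix (L, B), class index at each labeling position
    #   y_labeling: tensor (T, B, L), per-frame weight of each position
    #   returns:    tensor (T, B, C), per-frame weight of each class
    # FIXME: y_hat_mask is currently not used
    batch_size = y.shape[1]
    N = y_labeling.shape[0]  # number of frames T
    n_labels = y.shape[0]    # labeling length L
    batch_idx = T.arange(batch_size)
    # Sum over all repeated labels: (T, B, L) -> (T, B, C).
    # The accumulator is kept as (C, B, T) so that each scan step can
    # index it by (class, batch) pairs.
    out = T.zeros((num_classes, batch_size, N))
    y_labeling = y_labeling.dimshuffle((2, 1, 0))  # L, B, T

    def scan_step(index, prev_res, y_labeling, y_):
        # Add the (B, T) slice for labeling position `index` to the
        # rows prev_res[y_[index, b], b] for every batch element b.
        res_t = T.inc_subtensor(prev_res[y_[index, batch_idx], batch_idx],
                                y_labeling[index, batch_idx])
        return res_t

    result, updates = theano.scan(scan_step,
                                  sequences=[T.arange(n_labels)],
                                  non_sequences=[y_labeling, y],
                                  outputs_info=[out])
    # result stacks every intermediate accumulator; the last one holds the
    # full sum with shape (C, B, T), so return it as (T, B, C)
    return result[-1].dimshuffle(2, 1, 0)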
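To make the shape bookkeeping concrete, here is a minimal NumPy sketch of the same computation. It assumes, as the comments in the function indicate, that `y` holds integer class indices with shape (L, B) and `y_labeling` holds float weights with shape (T, B, L). The name `labeling_batch_to_class_batch_np` is hypothetical and not part of KGP-ASR; `np.add.at` stands in for the per-position `T.inc_subtensor` loop that the Theano version runs through `theano.scan`.

```python
import numpy as np


def labeling_batch_to_class_batch_np(y, y_labeling, num_classes):
    """Hypothetical NumPy reference, not the project's code.

    y:          int array (L, B), class index at each labeling position
    y_labeling: float array (T, B, L), per-frame weight of each position
    returns:    float array (T, B, C), weights summed per class
    """
    n_time, batch_size, _ = y_labeling.shape
    out = np.zeros((n_time, batch_size, num_classes), dtype=y_labeling.dtype)
    # Index arrays broadcast together to (T, B, L):
    #   t_idx picks the frame, b_idx the batch element, cls the class.
    t_idx = np.arange(n_time)[:, None, None]
    b_idx = np.arange(batch_size)[None, :, None]
    cls = y.T[None, :, :]
    # np.add.at is unbuffered, so positions that share a class
    # (a label that repeats in the labeling) are summed, not overwritten.
    np.add.at(out, (t_idx, b_idx, cls), y_labeling)
    return out


# Labeling [0, 1, 0] for a single utterance, over two frames.
y = np.array([[0], [1], [0]])
y_labeling = np.array([[[0.2, 0.5, 0.3]], [[0.1, 0.1, 0.8]]])
print(labeling_batch_to_class_batch_np(y, y_labeling, num_classes=2))
```

In the usage example, class 0 appears at positions 0 and 2, so at frame 0 it receives 0.2 + 0.3 = 0.5 while class 1 receives 0.5 from position 1. Plain fancy-index assignment (`out[t_idx, b_idx, cls] += y_labeling`) would keep only one of the duplicate writes, which is why the scatter must accumulate. The Theano version avoids that problem differently: within one scan step the (class, batch) pairs are distinct, and repeats across positions are summed by successive steps.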