gram_ctc.py source code

python

Project: chainer-speech-recognition Author: musyoku
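import collections.abc

import numpy as np
from chainer import cuda, variable

# GramCTC (the Function that computes the loss) is assumed to be defined elsewhere in this file.
def gram_ctc(xs, label_unigram, label_bigram, blank_symbol, input_length=None, length_unigram=None, reduce='mean'):
    # collections.Sequence was removed in Python 3.10; collections.abc.Sequence is the supported name.
    if not isinstance(xs, collections.abc.Sequence):
        raise TypeError('xs must be a list of Variables')
    if not isinstance(blank_symbol, int):
        raise TypeError('blank_symbol must be a non-negative integer.')
    assert blank_symbol >= 0
    assert blank_symbol < xs[0].shape[1]
    assert len(xs[0].shape) == 2  # each time step has shape (batch, vocabulary)
    assert label_unigram.shape[1] == label_bigram.shape[1]

    xp = cuda.get_array_module(xs[0].data)
    if input_length is None:
        # Default: every sequence in the batch spans all len(xs) time steps.
        input_length = variable.Variable(xp.full((len(xs[0].data),), len(xs), dtype=np.int32))
    if length_unigram is None:
        # Default: every label row is used in full (no padding).
        # Filled in independently so a caller-supplied length_unigram is never overwritten.
        length_unigram = variable.Variable(xp.full((len(label_unigram.data),), len(label_unigram.data[0]), dtype=np.int32))

    return GramCTC(blank_symbol, reduce)(input_length, length_unigram, label_unigram, label_bigram, *xs)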
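
When no lengths are passed, the function above assumes that every sequence in the batch covers all time steps and that every label row is fully used. The following is a minimal sketch of those defaults in plain NumPy, so it runs without Chainer or a GPU. The helper name `default_lengths` and the sample shapes are hypothetical, chosen only for illustration; they are not part of the project.

```python
import numpy as np


def default_lengths(xs, label_unigram):
    """Hypothetical helper mirroring the default lengths built in gram_ctc.

    xs: list of per-time-step arrays, each of shape (batch, vocabulary).
    label_unigram: integer array of shape (batch, max_label_length).
    """
    batch_size = len(xs[0])
    # One entry per batch item, each equal to the number of time steps.
    input_length = np.full((batch_size,), len(xs), dtype=np.int32)
    # One entry per label row, each equal to the full row width.
    length_unigram = np.full(
        (len(label_unigram),), label_unigram.shape[1], dtype=np.int32
    )
    return input_length, length_unigram


# Made-up example: 5 time steps, batch of 2, vocabulary of 4.
xs = [np.zeros((2, 4), dtype=np.float32) for _ in range(5)]
label_unigram = np.array([[1, 2, 3], [2, 1, 0]], dtype=np.int32)

input_length, length_unigram = default_lengths(xs, label_unigram)
print(input_length)    # [5 5]
print(length_unigram)  # [3 3]
```

If the batch holds sequences of different true lengths, or the label rows are padded, these defaults overcount, so explicit `input_length` and `length_unigram` arrays should be passed instead.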