def gram_ctc(xs, label_unigram, label_bigram, blank_symbol, input_length=None, length_unigram=None, reduce='mean'):
    """Validate inputs, fill default lengths, and apply ``GramCTC``.

    Args:
        xs: Sequence (list/tuple) of Variables, one entry per step. Each
            entry must be 2-D, and ``xs[0].shape[1]`` bounds
            ``blank_symbol``. Presumably each entry is (batch, n_units);
            TODO confirm against callers.
        label_unigram: Unigram labels, accessed via ``.shape`` and ``.data``.
            ``label_unigram.shape[1]`` must equal ``label_bigram.shape[1]``.
        label_bigram: Bigram labels, accessed via ``.shape``.
        blank_symbol (int): Index of the blank symbol. Must satisfy
            ``0 <= blank_symbol < xs[0].shape[1]``.
        input_length: Optional lengths. When ``None``, defaults to an int32
            Variable with ``len(xs[0].data)`` entries, each equal to
            ``len(xs)``.
        length_unigram: Optional unigram label lengths. When ``None``,
            defaults to an int32 Variable with ``len(label_unigram.data)``
            entries, each equal to ``len(label_unigram.data[0])``. A
            caller-supplied value is always used as given.
        reduce (str): Forwarded unchanged to ``GramCTC``.

    Returns:
        The result of calling ``GramCTC(blank_symbol, reduce)`` with the
        lengths, the labels, and the elements of ``xs``.

    Raises:
        TypeError: If ``xs`` is not a sequence or ``blank_symbol`` is not an
            int.
        ValueError: If ``xs`` is empty.
        AssertionError: If a rank, range, or label-shape check fails.
    """
    # The ABC lives in ``collections.abc``; the old ``collections.Sequence``
    # alias no longer exists on Python 3.10+.
    from collections.abc import Sequence

    if not isinstance(xs, Sequence):
        raise TypeError('xs must be a list of Variables')
    if len(xs) == 0:
        raise ValueError('xs must not be empty')
    if not isinstance(blank_symbol, int):
        raise TypeError('blank_symbol must be non-negative integer.')

    # Check the rank before indexing shape[1], so a 1-D input trips this
    # assertion instead of raising an IndexError.
    assert len(xs[0].shape) == 2
    assert 0 <= blank_symbol < xs[0].shape[1]
    assert label_unigram.shape[1] == label_bigram.shape[1]

    # Fill each missing length independently, so a caller-supplied
    # length_unigram is never overwritten and never passed on as None.
    if input_length is None or length_unigram is None:
        xp = cuda.get_array_module(xs[0].data)
        if input_length is None:
            input_length = variable.Variable(
                xp.full((len(xs[0].data),), len(xs), dtype=np.int32))
        if length_unigram is None:
            length_unigram = variable.Variable(
                xp.full((len(label_unigram.data),),
                        len(label_unigram.data[0]), dtype=np.int32))

    return GramCTC(blank_symbol, reduce)(
        input_length, length_unigram, label_unigram, label_bigram, *xs)
评论列表
文章目录