def main():
    """Evaluate a trained model on the test set and print the CER of each bucket.

    Configuration comes from module-level names defined elsewhere in this file:
    ``args`` (``model_dir``, ``gpu_device``, ``augmentation``,
    ``buckets_limit``, ``filter_bucket_id``) and the test-data paths
    ``wav_path_test``, ``trn_path_test`` and ``cache_path``.

    Buckets that received no minibatches (e.g. because ``filter_bucket_id``
    selected a single bucket) are omitted from the report instead of
    crashing with ``ZeroDivisionError``.

    Raises:
        RuntimeError: if ``load_model`` returns ``None`` for ``args.model_dir``.
    """
    # Vocabulary: only the inverse table (passed to the error routine for
    # printing sequences) and the blank label id are needed here.
    _, vocab_inv, BLANK = get_vocab()

    # Per-bucket batch sizes, non-increasing toward later buckets. The original
    # (garbled) comment mentioned a single GTX 1080 — presumably the sizes were
    # tuned to fit that GPU's memory.
    batchsizes = [96, 64, 64, 64, 64, 64, 64, 64, 48, 48, 48, 32, 32, 24, 24, 24, 24, 24, 24, 24, 24, 24]

    augmentation = AugmentationOption()
    if args.augmentation:
        augmentation.change_vocal_tract = True
        augmentation.change_speech_rate = True
        augmentation.add_noise = True

    model = load_model(args.model_dir)
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if model is None:
        raise RuntimeError("could not load a model from {}".format(args.model_dir))
    if args.gpu_device >= 0:
        chainer.cuda.get_device(args.gpu_device).use()
        model.to_gpu(args.gpu_device)
    xp = model.xp

    # Evaluation: test-mode layers, and no computational graph is retained
    # because we never backpropagate here.
    with chainer.using_config("train", False), chainer.no_backprop_mode():
        iterator = TestMinibatchIterator(
            wav_path_test, trn_path_test, cache_path, batchsizes, BLANK,
            buckets_limit=args.buckets_limit, option=augmentation,
            gpu=args.gpu_device >= 0)
        # buckets_errors[i] collects the per-minibatch errors of bucket i.
        buckets_errors = []
        for batch in iterator:
            x_batch, x_length_batch, t_batch, t_length_batch, bucket_idx, progress = batch
            # NOTE(review): a filter value of 0 is falsy and disables filtering,
            # so bucket 0 cannot be selected alone; the id is also compared
            # 0-based while the report prints 1-based ids — confirm against the
            # argparse definition of --filter_bucket_id.
            if args.filter_bucket_id and bucket_idx != args.filter_bucket_id:
                continue
            sys.stdout.write("\r" + stdout.CLEAR)
            sys.stdout.write("computing CER of bucket {} ({} %)".format(bucket_idx + 1, int(progress * 100)))
            sys.stdout.flush()
            y_batch = model(x_batch, split_into_variables=False)
            # Greedy decoding: take the highest-scoring label along axis 2
            # (presumably (batch, time, vocab) — TODO confirm).
            y_batch = xp.argmax(y_batch.data, axis=2)
            error = compute_minibatch_error(y_batch, t_batch, BLANK, print_sequences=True, vocab=vocab_inv)
            # Grow the list so that bucket_idx is a valid index; skipped
            # buckets are left as empty lists.
            while bucket_idx >= len(buckets_errors):
                buckets_errors.append([])
            buckets_errors[bucket_idx].append(error)

    sys.stdout.write("\r" + stdout.CLEAR)
    sys.stdout.flush()
    print_bold("bucket CER")
    for bucket_idx, errors in enumerate(buckets_errors):
        # Empty buckets (filtered out or never yielded) have nothing to
        # average; dividing by len([]) would raise ZeroDivisionError.
        if not errors:
            continue
        avg_error = sum(errors) / len(errors)
        print("{} {}".format(bucket_idx + 1, avg_error * 100))
评论列表
文章目录