test.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:chainer-speech-recognition 作者: musyoku 项目源码 文件源码
def main():
    """Evaluate a trained model on the test set and print the CER of each bucket.

    Configuration comes from the module-level ``args`` (``model_dir``,
    ``gpu_device``, ``augmentation``, ``buckets_limit``, ``filter_bucket_id``)
    and from the module-level ``wav_path_test``, ``trn_path_test`` and
    ``cache_path``.

    For every test minibatch, the model runs with ``train`` disabled. The
    argmax over axis 2 of its output is scored against the targets with
    ``compute_minibatch_error``. The mean error of each bucket that produced
    at least one minibatch is then printed as a percentage, with buckets
    numbered from 1.
    """
    # Only the inverse vocabulary (for printing sequences) and the BLANK id are used.
    _vocab, vocab_inv, BLANK = get_vocab()

    # Minibatch size per bucket, indexed by bucket id.
    # Original note: sized for a single GTX 1080.
    batchsizes = [96, 64, 64, 64, 64, 64, 64, 64, 48, 48, 48, 32, 32, 24, 24, 24, 24, 24, 24, 24, 24, 24]

    augmentation = AugmentationOption()
    if args.augmentation:
        augmentation.change_vocal_tract = True
        augmentation.change_speech_rate = True
        augmentation.add_noise = True

    model = load_model(args.model_dir)
    assert model is not None

    if args.gpu_device >= 0:
        chainer.cuda.get_device(args.gpu_device).use()
        model.to_gpu(args.gpu_device)
    xp = model.xp

    # Evaluation: disable train-mode behaviour for the whole pass.
    with chainer.using_config("train", False):
        iterator = TestMinibatchIterator(wav_path_test, trn_path_test, cache_path, batchsizes, BLANK, buckets_limit=args.buckets_limit, option=augmentation, gpu=args.gpu_device >= 0)

        # bucket_idx -> list of per-minibatch errors.
        # A dict (instead of a list padded with empty entries) keeps buckets
        # that produced no minibatches (e.g. ones skipped by the filter
        # below) out of the average, avoiding a division by zero.
        buckets_errors = {}
        for batch in iterator:
            x_batch, _x_length_batch, t_batch, _t_length_batch, bucket_idx, progress = batch

            # NOTE(review): a filter value of 0 is falsy and disables filtering,
            # so bucket 0 cannot be selected alone. The comparison is also
            # against the 0-based index, while the output below is 1-based.
            # Confirm the intended CLI semantics.
            if args.filter_bucket_id and bucket_idx != args.filter_bucket_id:
                continue

            sys.stdout.write("\r" + stdout.CLEAR)
            sys.stdout.write("computing CER of bucket {} ({} %)".format(bucket_idx + 1, int(progress * 100)))
            sys.stdout.flush()

            y_batch = model(x_batch, split_into_variables=False)
            y_batch = xp.argmax(y_batch.data, axis=2)
            error = compute_minibatch_error(y_batch, t_batch, BLANK, print_sequences=True, vocab=vocab_inv)

            buckets_errors.setdefault(bucket_idx, []).append(error)

        # float() avoids truncating integer division on Python 2.
        avg_errors = [
            (bucket_idx, sum(buckets_errors[bucket_idx]) / float(len(buckets_errors[bucket_idx])))
            for bucket_idx in sorted(buckets_errors)
        ]

        sys.stdout.write("\r" + stdout.CLEAR)
        sys.stdout.flush()

        print_bold("bucket  CER")
        for bucket_idx, error in avg_errors:
            print("{}   {}".format(bucket_idx + 1, error * 100))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号