infer.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:dawn-bench-models 作者: stanford-futuredata 项目源码 文件源码
def infer(dataset_dir, run_dir, output_file, start, end, repeat, log2,
          cpu, gpu, append, models):
    """Run CIFAR-10 inference timing for every selected model under ``run_dir``.

    For each model, the first (in sorted order) checkpoint matching
    ``<run_dir>/<model>/*/checkpoint_best_model.t7`` is loaded into a
    ``PyTorchEngine`` and handed to ``infer_cifar10``, once on CPU and/or
    once on GPU depending on the flags.

    Args:
        dataset_dir: Passed as ``root`` to ``datasets.CIFAR10`` (test split,
            ``download=True``).
        run_dir: Directory containing one subdirectory per model; the output
            file is also placed here.
        output_file: File name, relative to ``run_dir``, that is passed to
            ``infer_cifar10`` as ``output``.
        start, end, log2, repeat: Forwarded unchanged to ``infer_cifar10``.
            NOTE(review): presumably these describe the batch-size sweep and
            the number of timing repetitions -- confirm against
            ``infer_cifar10``.
        cpu: If truthy, run the benchmark with ``use_cuda=False``.
        gpu: If truthy and CUDA is available, run the benchmark with
            ``use_cuda=True`` (after a single warm-up batch). When CUDA is
            not available the GPU run is skipped silently.
        append: If falsy, an already existing output file is an error.
        models: Iterable of model names (subdirectory names of ``run_dir``,
            also passed as ``arch`` to ``PyTorchEngine``). If empty or
            ``None``, every subdirectory of ``run_dir`` is used.

    Raises:
        FileExistsError: The output file exists and ``append`` is falsy.
        FileNotFoundError: No checkpoint was found for one of the models.
    """
    output_path = os.path.join(run_dir, output_file)
    # Explicit exception instead of ``assert`` so the guard survives
    # ``python -O``; checked before the (potentially slow) dataset download.
    if os.path.exists(output_path) and not append:
        raise FileExistsError(
            f"Output file already exists: {output_path} "
            "(enable append to add to it)")

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD)
    ])

    testset = datasets.CIFAR10(root=dataset_dir, train=False, download=True,
                               transform=transform_test)

    if not models:
        # Only subdirectories can be models. A plain ``os.listdir`` would also
        # return regular files in run_dir -- notably the output file itself
        # when appending -- which would then fail the checkpoint lookup.
        models = sorted(
            name for name in os.listdir(run_dir)
            if os.path.isdir(os.path.join(run_dir, name))
        )

    for model in models:
        model_dir = os.path.join(run_dir, model)
        # glob returns matches in arbitrary order; sort so the chosen
        # checkpoint is deterministic when several runs exist.
        paths = sorted(glob(f"{model_dir}/*/checkpoint_best_model.t7"))
        if not paths:
            raise FileNotFoundError(
                f"No checkpoint_best_model.t7 found under {model_dir}/*/")
        path = os.path.abspath(paths[0])

        print(f'Model: {model}')
        print(f'Path: {path}')

        if cpu:
            print('With CPU:')
            engine = PyTorchEngine(path, use_cuda=False, arch=model)
            infer_cifar10(testset, engine, start=start, end=end, log2=log2,
                          repeat=repeat, output=output_path)

        if gpu and torch.cuda.is_available():
            print('With GPU:')
            engine = PyTorchEngine(path, use_cuda=True, arch=model)
            # Warmup: one single-image batch before the timed runs.
            time_batch_size(testset, 1, engine.pred, engine.use_cuda, repeat=1)

            infer_cifar10(testset, engine, start=start, end=end, log2=log2,
                          repeat=repeat, output=output_path)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号