attack.py 文件源码

python
阅读 98 收藏 0 点赞 0 评论 0

项目:pytorch-nips2017-attack-example 作者: rwightman 项目源码 文件源码
def run_attack(args, attack):
    """Run ``attack`` on every batch of images in ``args.input_dir`` and save the results.

    Builds the dataset and loader, instantiates an Inception v3 model, loads
    weights from ``args.checkpoint_path`` when that file exists, and then
    writes one PNG per input image into ``args.output_dir``.

    Args:
        args: Options object. Attributes read here: ``input_dir``,
            ``targeted``, ``img_size``, ``batch_size``, ``no_gpu``,
            ``checkpoint_path`` and ``output_dir``.
        attack: Object exposing ``run(model, input, target, batch_idx)``.
            Its return value is iterated image by image and handed to
            ``imsave``.
            NOTE(review): the ``(o + 1.0) * 0.5`` rescale below suggests each
            image is in the [-1, 1] range in a layout ``imsave`` accepts —
            confirm against the attack implementation.

    Raises:
        AssertionError: If ``args.input_dir`` is empty or unset.
    """
    assert args.input_dir

    # Untargeted runs pass an empty ``target_file``; targeted runs rely on
    # the Dataset's own default for that argument.
    dataset_kwargs = {'transform': default_inception_transform(args.img_size)}
    if not args.targeted:
        dataset_kwargs['target_file'] = ''
    dataset = Dataset(args.input_dir, **dataset_kwargs)

    # shuffle=False keeps batch order aligned with dataset order, which the
    # filename lookup at the bottom of the loop depends on.
    loader = data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False)

    use_gpu = not args.no_gpu

    model = torchvision.models.inception_v3(pretrained=False, transform_input=False)
    if use_gpu:
        model = model.cuda()

    if args.checkpoint_path is not None and os.path.isfile(args.checkpoint_path):
        # Deserialize onto the CPU so that a checkpoint saved from a GPU can
        # still be loaded when running with --no_gpu or on a machine with a
        # different device layout. load_state_dict copies the values into the
        # model's own parameters, wherever they live.
        checkpoint = torch.load(
            args.checkpoint_path,
            map_location=lambda storage, loc: storage)
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint
        model.load_state_dict(state_dict)
    else:
        # Best effort: report the problem and continue with the model's
        # freshly initialized weights.
        print("Error: No checkpoint found at %s." % args.checkpoint_path)

    model.eval()

    for batch_idx, (batch_input, batch_target) in enumerate(loader):
        if use_gpu:
            batch_input = batch_input.cuda()
            batch_target = batch_target.cuda()

        input_adv = attack.run(model, batch_input, batch_target, batch_idx)

        # Recover the dataset indices covered by this batch; the last batch
        # may be smaller than args.batch_size, hence size(0).
        start_index = args.batch_size * batch_idx
        indices = list(range(start_index, start_index + batch_input.size(0)))
        for filename, o in zip(dataset.filenames(indices, basename=True), input_adv):
            output_file = os.path.join(args.output_dir, filename)
            imsave(output_file, (o + 1.0) * 0.5, format='png')
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号