def _build_arg_parser():
    """Return the command-line parser for the Grad-CAM visualization script.

    Options are split into two help groups: model options (which network
    to build and which weights to load) and runtime options (device and
    input/output locations).
    """
    parser = argparse.ArgumentParser(description='Visualize Saliency')

    model_group = parser.add_argument_group('Model Parameters')
    model_group.add_argument('--model_name', type=str,
                             help='Model name {"SimpleCNN", "MiddleCNN"}')
    model_group.add_argument('--init_model', type=str,
                             help='Initialize the model from given file')
    model_group.add_argument('--n_classes', type=int, default=48,
                             help='Number of classes')

    runtime_group = parser.add_argument_group('Runtime Parameters')
    runtime_group.add_argument('--gpu', type=int,
                               help='GPU ID (negative value indicates CPU)')
    runtime_group.add_argument('--test_dir', type=str,
                               help='/path/to/test_dir')
    runtime_group.add_argument('--nb_output', type=int, default=10,
                               help='Number of output images')
    runtime_group.add_argument('--save_dir', type=str, default='./grad_cam',
                               help='Save directory')

    parser.add_argument('--debug', action='store_true',
                        help='if specified, using chainer.set_debug()')
    return parser


def _sample_filenames(filenames, nb_output):
    """Return up to ``nb_output`` distinct entries of ``filenames`` in random order.

    The count is clamped to ``[0, len(filenames)]``, so a directory with
    fewer files than requested, or a non-positive ``nb_output``, never raises.
    """
    import random
    count = max(0, min(nb_output, len(filenames)))
    return random.sample(filenames, count)


def main():
    """Save Grad-CAM overlays for a random sample of test images per class.

    For every index ``idx`` in ``range(len(fonts_dict))`` that has a
    sub-directory ``<test_dir>/<fonts_dict[idx]>``, up to ``--nb_output``
    files from that directory are drawn at random and run through the
    Grad-CAM model. When the predicted index equals ``idx``, the image and
    its mask are written via ``save_cam_image`` to
    ``<save_dir>/<fonts_dict[idx]>/<filename>``. Otherwise the true and
    predicted names are printed.

    Exits with a usage error (via ``argparse``) when ``--init_model`` or
    ``--test_dir`` is missing.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # Use parser.error instead of assert: asserts are stripped under -O,
    # and a missing test_dir would otherwise surface as a TypeError from
    # os.path.join(None, ...).
    if args.init_model is None:
        parser.error('init_model must be specified.')
    if args.test_dir is None:
        parser.error('test_dir must be specified.')

    chainer.set_debug(args.debug)

    # load model
    grad_cam = build_gradcam_model(args.n_classes, args.model_name,
                                   args.init_model, args.gpu)

    # Visualization
    # fonts_dict is indexed by integer position here; it may be a list or an
    # int-keyed dict, so iterate indices rather than enumerate() it.
    for idx in range(len(fonts_dict)):
        class_name = fonts_dict[idx]
        target_dir = os.path.join(args.test_dir, class_name)
        if not os.path.isdir(target_dir):
            continue
        save_dir = os.path.join(args.save_dir, class_name)

        picked = _sample_filenames(sorted(os.listdir(target_dir)),
                                   args.nb_output)
        for filename in picked:
            img = imread(os.path.join(target_dir, filename),
                         mode='RGB').astype(np.float32)
            arr = convert_to_array(img, args.gpu)
            mask, pred_idx = grad_cam(arr, None)
            if idx == pred_idx:
                # exist_ok avoids the isdir/makedirs check-then-act race.
                os.makedirs(save_dir, exist_ok=True)
                save_cam_image(img, mask, os.path.join(save_dir, filename))
            else:
                print("true :", class_name, "!= predict :",
                      fonts_dict[pred_idx])
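# 评论列表 ("comment list") and 文章目录 ("table of contents"): blog-page
# navigation labels captured when this script was copied from the web; not part of the program.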