def main():
    """Train and evaluate the adapted-AlexNet model on PASCAL VOC 2007.

    Collects user arguments and hyper-parameters, sets up the compute
    backend, trains on the 'trainval' split, then evaluates on the
    'test' split. Datasets are downloaded on first use and cached
    locally under ``args.data_dir``.
    """
    # Collect the user arguments and hyper parameters.
    args, hyper_params = get_args_and_hyperparameters()
    np.set_printoptions(precision=8, suppress=True, edgeitems=6, threshold=2048)

    # Set up the CPU or GPU backend. The returned handle was unused;
    # the call is kept for its side effect (presumably gen_backend
    # registers the backend globally — TODO confirm against neon docs).
    gen_backend(**extract_valid_args(args, gen_backend))

    def _make_sampler(split, max_imgs):
        # Build a VOC-2007 MultiscaleSampler for the given split; the
        # two original call sites differed only in split and max_imgs.
        return MultiscaleSampler(split, '2007',
                                 samples_per_img=hyper_params.samples_per_img,
                                 sample_height=224,
                                 path=args.data_dir,
                                 samples_per_batch=hyper_params.samples_per_batch,
                                 max_imgs=max_imgs,
                                 shuffle=hyper_params.shuffle)

    # Training dataset ('trainval' split).
    train_set = _make_sampler('trainval', hyper_params.max_train_imgs)

    # Create the model by replacing the classification layer of AlexNet
    # with new adaptation layers.
    model, opt = create_model(args, hyper_params)

    # Seed the AlexNet conv layers with pre-trained weights, unless the
    # user supplied a saved model file of their own.
    if args.model_file is None and hyper_params.use_pre_trained_weights:
        load_imagenet_weights(model, args.data_dir)

    train(args, hyper_params, model, opt, train_set)

    # Evaluation dataset ('test' split).
    test_set = _make_sampler('test', hyper_params.max_test_imgs)
    test(args, hyper_params, model, test_set)
# NOTE: command-line parsing is handled inside get_args_and_hyperparameters().