transfer_learning.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:ModelZoo 作者: NervanaSystems 项目源码 文件源码
def main():
    """Train the adapted model on the 2007 'trainval' split, then evaluate on 'test'.

    The pipeline runs strictly in this order:
      1. read the command-line arguments and hyper-parameters;
      2. configure numpy printing and create the CPU/GPU backend;
      3. build the training sampler, then the model and its optimizer;
      4. seed the conv layers with pre-trained weights when requested and no
         model file was supplied;
      5. train, then build the test sampler and run evaluation.

    Returns None.
    """
    args, hyper_params = get_args_and_hyperparameters()

    def make_sampler(split, image_limit):
        # Both splits share every sampler setting except the split name and
        # the maximum image count, so construct them through one helper.
        return MultiscaleSampler(split, '2007',
                                 samples_per_img=hyper_params.samples_per_img,
                                 sample_height=224,
                                 path=args.data_dir,
                                 samples_per_batch=hyper_params.samples_per_batch,
                                 max_imgs=image_limit,
                                 shuffle=hyper_params.shuffle)

    np.set_printoptions(precision=8, suppress=True, edgeitems=6, threshold=2048)

    # Create the CPU or GPU backend. extract_valid_args presumably keeps only the
    # user arguments gen_backend accepts (TODO confirm).
    be = gen_backend(**extract_valid_args(args, gen_backend))

    # According to the original comments, building a sampler downloads the
    # dataset from the web the first time and caches it locally afterwards.
    train_set = make_sampler('trainval', hyper_params.max_train_imgs)

    # AlexNet with its classification layer replaced by new adaptation layers.
    model, opt = create_model(args, hyper_params)

    # Pre-trained weights are loaded only when no model file was given. The
    # hyper-parameter flag is consulted only in that case.
    if args.model_file is None and hyper_params.use_pre_trained_weights:
        load_imagenet_weights(model, args.data_dir)

    train(args, hyper_params, model, opt, train_set)

    # The test split is loaded only after training has finished.
    test_set = make_sampler('test', hyper_params.max_test_imgs)
    test(args, hyper_params, model, test_set)

# parse the command line arguments
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号