eval_mvn.py source code

python

Project: sdp    Author: tansey
import argparse
import os

import numpy as np
import tensorflow as tf

# NOTE: create_model and explicit_score are utilities from the sdp project;
# their import statements are not included in the original snippet.


def main():
    parser = argparse.ArgumentParser(description='Predicts pixel intensities given a random subset of an image.')

    # Experiment settings
    parser.add_argument('--inputdir', default='experiments/pixels/data', help='The directory where the input data files will be stored.')
    parser.add_argument('--outputdir', default='experiments/pixels/results', help='The directory where the output files will be stored.')
    parser.add_argument('--variable_scope', default='pixels-', help='The variable scope that the model will be created with.')
    parser.add_argument('--train_id', type=int, default=0, help='A trial ID. All models trained with the same trial ID will use the same train/validation datasets.')
    parser.add_argument('--train_samples', type=int, default=50000, help='The number of training examples to use.')
    parser.add_argument('--test_samples', type=int, default=10000, help='The number of test examples to use.')
    parser.add_argument('--validation_pct', type=float, default=0.2,
                        help='The fraction of training samples to hold out as a validation set.')
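    # With the defaults above, validation_pct=0.2 corresponds to holding out
    # 0.2 * 50000 = 10,000 of the training samples for validation.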
    parser.add_argument('--dimsize', type=int, default=256, help='The number of bins for each subpixel intensity (max 256, must be a power of 2).')
    parser.add_argument('--batchsize', type=int, default=50, help='The number of training samples per mini-batch.')

    # GMM/LMM settings
    parser.add_argument('--num_components', type=int, default=5, help='The number of mixture components for gmm or lmm models.')

    # Get the arguments from the command line
    args = parser.parse_args()
    dargs = vars(args)
    dargs['model'] = 'gmm'
    dargs['dataset'] = 'cifar'
    dargs['outfile'] = os.path.join(dargs['outputdir'], '{model}_{dataset}_{train_samples}_{num_components}_{train_id}'.format(**dargs))
    dargs['variable_scope'] = '{model}-{dataset}-{train_samples}-{num_components}-{train_id}'.format(**dargs)
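    # With the default settings these format strings resolve to, e.g.:
    #   outfile        = 'experiments/pixels/results/gmm_cifar_50000_5_0'
    #   variable_scope = 'gmm-cifar-50000-5-0'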


    # Get the data
    from cifar_utils import DataLoader
    train_data = DataLoader(args.inputdir, 'train', args.train_samples, args.batchsize, seed=args.train_id, dimsize=args.dimsize)
    validate_data = DataLoader(args.inputdir, 'validate', args.train_samples, args.batchsize, seed=args.train_id, dimsize=args.dimsize)
    test_data = DataLoader(args.inputdir, 'test', args.test_samples, args.batchsize, seed=args.train_id, dimsize=args.dimsize)

    dargs['x_shape'] = train_data.x_shape()
    dargs['y_shape'] = train_data.y_shape()
    dargs['lazy_density'] = True # density is too big to enumerate for cifar
    dargs['one_hot'] = False # We use just the intensities not a one-hot

    sess = tf.Session(config=tf.ConfigProto(log_device_placement=False))

    # Get the X placeholder and the output distribution model
    tf_X, dist = create_model(**dargs)
    saver = tf.train.Saver()

    sess.run(tf.global_variables_initializer())

    # Reset the model back to the best version
    saver.restore(sess, dargs['outfile'])

    logprobs, rmse = explicit_score(sess, args.model, dist, test_data, tf_X)
    print(logprobs, rmse)
    # best_loss, args.k, and args.lam from the original line are not defined anywhere
    # in this evaluation script, so only the values computed here are written out.
    np.savetxt(dargs['outfile'] + '_score.csv', [logprobs, rmse, args.num_components])
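
The snippet ends without a script entry point; assuming the file is meant to be run directly from the command line (as the argparse setup suggests), a standard guard would be:

if __name__ == '__main__':
    main()

# Hypothetical invocation, using only the flags defined above:
#   python eval_mvn.py --inputdir experiments/pixels/data --outputdir experiments/pixels/results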