sample_convnade.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:NADE 作者: MarcCote 项目源码 文件源码
def main():
    """Generate samples from a saved model, then optionally save and display them.

    Command-line arguments come from ``buildArgsParser()``. The attributes
    read here are:

    * ``args.experiment`` -- directory the model is loaded from; also the
      directory the output file is written into.
    * ``args.count`` -- number of samples requested from the sampling function.
    * ``args.seed`` -- passed both as ``seed`` when building the sampling
      function and as ``ordering_seed`` when calling it.
    * ``args.out`` -- output file name, joined to ``args.experiment``; nothing
      is saved when it is ``None``.
    * ``args.view`` -- when truthy, the samples and their probabilities are
      shown in two matplotlib figures.

    Raises:
        ValueError: if ``args.view`` is set and the ``"dataset"`` entry of the
            hyperparameters is not ``"binarized_mnist"`` (the only dataset
            whose image shape is known here).
    """
    parser = buildArgsParser()
    args = parser.parse_args()

    # Load the experiment's hyperparameters, falling back to the parent
    # directory when they cannot be read from the experiment folder itself.
    try:
        hyperparams = smartutils.load_dict_from_json_file(pjoin(args.experiment, "hyperparams.json"))
    except Exception:
        # `Exception` rather than a bare `except:` so that KeyboardInterrupt
        # and SystemExit still abort the script instead of silently
        # triggering the fallback.
        hyperparams = smartutils.load_dict_from_json_file(pjoin(args.experiment, '..', "hyperparams.json"))

    model = load_model(args.experiment)
    print(str(model))

    with Timer("Generating {} samples from Conv Deep NADE".format(args.count)):
        sample = model.build_sampling_function(seed=args.seed)
        samples, probs = sample(args.count, return_probs=True, ordering_seed=args.seed)

    if args.out is not None:
        outfile = pjoin(args.experiment, args.out)
        with Timer("Saving {0} samples to '{1}'".format(args.count, outfile)):
            np.save(outfile, samples)

    if args.view:
        # Plotting dependencies are imported lazily so they are only needed
        # when a display is actually requested.
        import pylab as plt
        from convnade import vizu

        # The image shape is needed to lay the samples out as pictures.
        if hyperparams["dataset"] == "binarized_mnist":
            image_shape = (28, 28)
        else:
            raise ValueError("Unknown dataset: {0}".format(hyperparams["dataset"]))

        # One figure for the drawn samples ...
        plt.figure()
        data = vizu.concatenate_images(samples, shape=image_shape, border_size=1, clim=(0, 1))
        plt.imshow(data, cmap=plt.cm.gray, interpolation='nearest')
        plt.title("Samples")

        # ... and one for the probabilities returned alongside them.
        plt.figure()
        data = vizu.concatenate_images(probs, shape=image_shape, border_size=1, clim=(0, 1))
        plt.imshow(data, cmap=plt.cm.gray, interpolation='nearest')
        plt.title("Probs")

        plt.show()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号