def _load_hyperparams(experiment_dir):
    """Return the hyperparameters dict saved for an experiment.

    ``hyperparams.json`` is read from *experiment_dir* when that file exists;
    otherwise the copy in the parent directory of *experiment_dir* is used.

    Args:
        experiment_dir: path of the experiment directory.

    Returns:
        Whatever ``smartutils.load_dict_from_json_file`` returns for the
        selected file.
    """
    import os

    path = pjoin(experiment_dir, "hyperparams.json")
    if not os.path.isfile(path):
        # Fall back to the parent directory. Testing for the file explicitly
        # replaces a bare ``except:``, which also swallowed KeyboardInterrupt
        # and silently used the parent's file when the experiment's own
        # hyperparams.json existed but failed to load.
        path = pjoin(experiment_dir, '..', "hyperparams.json")
    return smartutils.load_dict_from_json_file(path)


def _image_shape_for(dataset):
    """Return the image shape used to tile samples of *dataset* for display.

    Raises:
        ValueError: if *dataset* has no known image shape.
    """
    image_shapes = {"binarized_mnist": (28, 28)}
    if dataset not in image_shapes:
        raise ValueError("Unknown dataset: {0}".format(dataset))
    return image_shapes[dataset]


def _show_samples(samples, probs, image_shape):
    """Show *samples* and *probs* as two tiled grayscale figures, then call ``plt.show()``."""
    # Imported lazily so plotting dependencies are only needed with --view.
    import matplotlib.pyplot as plt
    from convnade import vizu

    for images, title in ((samples, "Samples"), (probs, "Probs")):
        plt.figure()
        grid = vizu.concatenate_images(images, shape=image_shape, border_size=1, clim=(0, 1))
        plt.imshow(grid, cmap=plt.cm.gray, interpolation='nearest')
        plt.title(title)
    plt.show()


def main():
    """Generate samples from a trained Conv Deep NADE model; optionally save and/or display them.

    Reads these command-line arguments from ``buildArgsParser()``:
    ``experiment`` (experiment directory), ``count`` (number of samples),
    ``seed`` (passed as both the sampling seed and the ordering seed),
    ``out`` (file name, relative to the experiment directory, for
    ``np.save``; skipped when None) and ``view`` (display the samples and
    their probabilities).

    Raises:
        ValueError: if ``view`` is set and the experiment's dataset has no
            known image shape.
    """
    parser = buildArgsParser()
    args = parser.parse_args()

    # Resolve everything --view needs before sampling. A missing hyperparams
    # file or an unsupported dataset then fails immediately, instead of after
    # a potentially long sampling run. Without --view, hyperparams are unused.
    image_shape = None
    if args.view:
        hyperparams = _load_hyperparams(args.experiment)
        image_shape = _image_shape_for(hyperparams["dataset"])

    model = load_model(args.experiment)
    print(str(model))

    with Timer("Generating {} samples from Conv Deep NADE".format(args.count)):
        sample = model.build_sampling_function(seed=args.seed)
        samples, probs = sample(args.count, return_probs=True, ordering_seed=args.seed)

    if args.out is not None:
        outfile = pjoin(args.experiment, args.out)
        with Timer("Saving {0} samples to '{1}'".format(args.count, outfile)):
            np.save(outfile, samples)

    if args.view:
        _show_samples(samples, probs, image_shape)
评论列表
文章目录