def main(args):
    """Train a GAN on MNIST or CIFAR-10 until ``args.num_epochs`` epochs elapse.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            ``device_id`` (GPU id; a negative value means CPU only),
            ``mnist`` (truthy selects MNIST, otherwise CIFAR-10),
            ``batchsize``, ``learning_rate`` (RMSprop lr for both networks),
            ``num_z`` (passed to ``UniformNoiseGenerator``), ``output``
            (directory created if missing) and ``num_epochs``. The whole
            namespace is also forwarded to ``training_step``.
    """
    use_gpu = args.device_id >= 0

    # If GPU mode is enabled, select the GPU to use.
    if use_gpu:
        chainer.cuda.get_device(args.device_id).use()

    # Load the dataset. Only the training split is used, so the test split is discarded.
    if args.mnist:
        train, _ = chainer.datasets.get_mnist(withlabel=False, scale=2, ndim=3)
        generator = GeneratorMNIST()
    else:
        train, _ = chainer.datasets.get_cifar10(withlabel=False, scale=2, ndim=3)
        generator = GeneratorCIFAR()
    # Pixels were scaled to [0, 2] above; subtracting 1 maps them to [-1, 1].
    # The in-place subtraction assumes `train` is a plain array (withlabel=False).
    train -= 1.0

    # Data iterator over the training images.
    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)

    # Construction order (generator, train iterator, discriminator, noise
    # iterator) matches the original, so any RNG draws made while building
    # these objects happen in the same sequence.
    discriminator = Discriminator()
    # Random noise iterator (uniform noise between -1 and 1).
    noise_iter = RandomNoiseIterator(UniformNoiseGenerator(-1, 1, args.num_z), args.batchsize)

    # Move the models to the GPU *before* attaching optimizers. Any optimizer
    # state allocated during setup then lives on the same device as the
    # parameters it tracks.
    if use_gpu:
        generator.to_gpu()
        discriminator.to_gpu()

    # Build one RMSprop optimizer per network.
    opt_generator = chainer.optimizers.RMSprop(lr=args.learning_rate)
    opt_generator.setup(generator)
    opt_discriminator = chainer.optimizers.RMSprop(lr=args.learning_rate)
    opt_discriminator.setup(discriminator)

    # Make the output folder. exist_ok=True already tolerates an existing
    # directory, so a separate exists() check would be redundant and racy.
    os.makedirs(args.output, exist_ok=True)

    print("Starting training loop...")
    while train_iter.epoch < args.num_epochs:
        training_step(args, train_iter, noise_iter, opt_generator, opt_discriminator)
    print("Finished training.")
评论列表
文章目录