def training_step(args, train_iter, noise_iter, opt_generator, opt_discriminator):
    """Run one discriminator update followed by one generator update.

    Both losses are sums of squared differences divided by
    ``args.batchsize``: the discriminator is scored against a target of
    1.0 on dataset samples and 0 on generated samples, the generator
    against a target of 1.0 on the discriminator's output for its own
    samples.

    Args:
        args: Settings object; ``device_id``, ``batchsize`` and ``output``
            are read here.
        train_iter: Iterator over the dataset; its ``is_new_epoch`` and
            ``epoch`` attributes drive the per-epoch report.
        noise_iter: Iterator supplying generator input batches.
        opt_generator: Optimizer whose ``target`` is called as the generator.
        opt_discriminator: Optimizer whose ``target`` is called as the
            discriminator.

    Returns:
        None.
    """
    # --- discriminator update -------------------------------------------
    latent = get_batch(noise_iter, args.device_id)
    fake_images = opt_generator.target(latent)
    real_images = get_batch(train_iter, args.device_id)

    real_scores = opt_discriminator.target(real_images)
    fake_scores = opt_discriminator.target(fake_images)

    real_term = F.sum((real_scores - 1.0) ** 2)
    fake_term = F.sum(fake_scores ** 2)
    Dloss = 0.5 * (real_term + fake_term) / args.batchsize
    # NOTE(review): update_model is defined elsewhere; presumably it
    # backpropagates the loss and steps the given optimizer -- confirm.
    update_model(opt_discriminator, Dloss)

    # --- generator update -----------------------------------------------
    # A fresh noise batch is drawn and scored by the discriminator again,
    # after the discriminator update above has been applied.
    noise_samples = get_batch(noise_iter, args.device_id)
    fake_images = opt_generator.target(noise_samples)
    rescored = opt_discriminator.target(fake_images)
    Gloss = 0.5 * F.sum((rescored - 1.0) ** 2) / args.batchsize
    update_model(opt_generator, Gloss)

    # --- per-epoch reporting --------------------------------------------
    if not train_iter.is_new_epoch:
        return

    report = "[{}] Discriminator loss: {} Generator loss: {}".format(
        train_iter.epoch, Dloss.data, Gloss.data
    )
    print(report)

    sample_path = os.path.join(
        args.output, "epoch_{}.png".format(train_iter.epoch)
    )
    # The sample image reuses the noise batch from the generator update.
    print_sample(sample_path, noise_samples, opt_generator)
评论列表
文章目录