def main():
    """Entry point: parse CLI args, prepare model and images, run generation."""
    config = Config(parse_args())
    # Make sure the output directory exists before anything is written.
    os.makedirs(config.output_dir, exist_ok=True)
    # Build the network described by the configuration.
    model = models.generate_model(config.model)
    # Load the content image, resized to the configured dimensions.
    img_orig = load_image(config.original_image, [config.width, config.height])
    # The style image is resized the same way unless resizing is disabled.
    style_size = None if config.no_resize_style else [config.width, config.height]
    img_style = load_image(config.style_image, style_size)
    # Run the generation loop.
    models.Generator(model, img_orig, img_style, config).generate(config)
Example source code for instances of the Python class Generator()
def main():
    """Parse command-line arguments, set up the model, and generate output."""
    args = parse_args()
    config = Config(args)
    # Ensure the output directory exists (idempotent).
    os.makedirs(config.output_dir, exist_ok=True)
    # Instantiate the configured model.
    model = models.generate_model(config.model)
    # The content image is always resized to (width, height).
    img_orig = load_image(config.original_image, [config.width, config.height])
    # The style image keeps its native size when no_resize_style is set.
    if config.no_resize_style:
        img_style = load_image(config.style_image, None)
    else:
        img_style = load_image(config.style_image, [config.width, config.height])
    generator = models.Generator(model, img_orig, img_style, config)
    generator.generate(config)
def train(args):
    """Train a WGAN on CIFAR-10 with the Wasserstein updater.

    Args:
        args: Parsed CLI namespace providing nz (noise dimensionality),
            batch_size, epochs, and gpu (device id; negative means CPU
            by Chainer convention — TODO confirm against parse_args).
    """
    nz = args.nz
    batch_size = args.batch_size
    epochs = args.epochs
    gpu = args.gpu

    # CIFAR-10 pixels arrive in [0, 2] with scale=2; shift down by 1 so the
    # data matches the [-1, 1] range of a tanh generator output.
    train, _ = datasets.get_cifar10(withlabel=False, ndim=3, scale=2)
    train -= 1.0

    train_iter = iterators.SerialIterator(train, batch_size)
    # Latent vectors z ~ N(0, 1), one batch per generator update.
    # (Fix: use the local `nz` instead of re-reading args.nz, which left
    # the local unused.)
    z_iter = RandomNoiseIterator(GaussianNoiseGenerator(0, 1, nz),
                                 batch_size)

    # WGAN prescribes small-learning-rate RMSprop for both networks.
    optimizer_generator = optimizers.RMSprop(lr=0.00005)
    optimizer_critic = optimizers.RMSprop(lr=0.00005)
    optimizer_generator.setup(Generator())
    optimizer_critic.setup(Critic())

    updater = WassersteinGANUpdater(
        iterator=train_iter,
        noise_iterator=z_iter,
        optimizer_generator=optimizer_generator,
        optimizer_critic=optimizer_critic,
        device=gpu)

    trainer = training.Trainer(updater, stop_trigger=(epochs, 'epoch'))
    trainer.extend(extensions.ProgressBar())
    trainer.extend(extensions.LogReport(trigger=(1, 'iteration')))
    # Dump generator samples once per epoch for visual inspection.
    trainer.extend(GeneratorSample(), trigger=(1, 'epoch'))
    trainer.extend(extensions.PrintReport(['epoch', 'iteration', 'critic/loss',
                                           'critic/loss/real', 'critic/loss/fake',
                                           'generator/loss']))
    trainer.run()
def __init__(self, nic, noc, ngf, ndf, beta=0.5, lamb=100, lr=1e-3, cuda=True, crayon=False):
    """
    Args:
        nic: Number of input channels
        noc: Number of output channels
        ngf: Number of generator filters
        ndf: Number of discriminator filters
        beta: Adam beta1 momentum coefficient
        lamb: Weight on L1 term in objective
        lr: Learning rate shared by both optimizers
        cuda: Move models to the GPU via self.cudafy when True
        crayon: Enable pycrayon (TensorBoard bridge) logging
    """
    self.cuda = cuda
    self.start_epoch = 0
    self.crayon = crayon
    if crayon:
        self.cc = CrayonClient(hostname="localhost", port=8889)
        try:
            self.logger = self.cc.create_experiment('pix2pix')
        except Exception:
            # The experiment likely survives from a previous run; recreate
            # it so this run starts with a clean log. (Fix: was a bare
            # `except:`, which would also swallow KeyboardInterrupt and
            # SystemExit.)
            self.cc.remove_experiment('pix2pix')
            self.logger = self.cc.create_experiment('pix2pix')

    self.gen = self.cudafy(Generator(nic, noc, ngf))
    self.dis = self.cudafy(Discriminator(nic, noc, ndf))

    # Optimizer for the generator
    self.gen_optim = self.cudafy(optim.Adam(
        self.gen.parameters(), lr=lr, betas=(beta, 0.999)))
    # Optimizer for the discriminator
    self.dis_optim = self.cudafy(optim.Adam(
        self.dis.parameters(), lr=lr, betas=(beta, 0.999)))

    # Loss functions: BCE for the adversarial terms, L1 for reconstruction.
    # (MSE is constructed but not used in this block — kept for callers.)
    self.criterion_bce = nn.BCELoss()
    self.criterion_mse = nn.MSELoss()
    self.criterion_l1 = nn.L1Loss()
    self.lamb = lamb
def train(self, loader, c_epoch):
    """Run one epoch of pix2pix adversarial training.

    Args:
        loader: Iterable of (input, target) image batch pairs.
        c_epoch: Current epoch number (used for progress reporting only).
    """
    self.dis.train()
    self.gen.train()
    self.reset_gradients()
    max_idx = len(loader)
    for idx, features in enumerate(tqdm(loader)):
        orig_x = Variable(self.cudafy(features[0]))
        orig_y = Variable(self.cudafy(features[1]))

        # --- Discriminator ---
        # Train with real pairs: D should label (x, y) as real.
        self.dis.volatile = False
        dis_real = self.dis(torch.cat((orig_x, orig_y), 1))
        real_labels = Variable(self.cudafy(
            torch.ones(dis_real.size())
        ))
        dis_real_loss = self.criterion_bce(
            dis_real, real_labels)

        # Train with fake pairs; detach so no gradient reaches G here.
        gen_y = self.gen(orig_x)
        dis_fake = self.dis(torch.cat((orig_x, gen_y.detach()), 1))
        fake_labels = Variable(self.cudafy(
            torch.zeros(dis_fake.size())
        ))
        dis_fake_loss = self.criterion_bce(
            dis_fake, fake_labels)

        # Update discriminator weights, then clear gradients.
        dis_loss = dis_real_loss + dis_fake_loss
        dis_loss.backward()
        self.dis_optim.step()
        self.reset_gradients()

        # --- Generator ---
        # NOTE(review): setting `volatile` on a module is a no-op (it was
        # a Variable attribute in old PyTorch), so gradients from the
        # generator loss still flow into D below.
        self.dis.volatile = True
        dis_real = self.dis(torch.cat((orig_x, gen_y), 1))
        real_labels = Variable(self.cudafy(
            torch.ones(dis_real.size())
        ))
        # Adversarial term plus lamb-weighted L1 reconstruction term.
        gen_loss = self.criterion_bce(dis_real, real_labels) + \
            self.lamb * self.criterion_l1(gen_y, orig_y)
        gen_loss.backward()
        self.gen_optim.step()
        # BUG FIX: clear gradients after the generator step. The backward
        # pass above also deposited gradients into D (volatile is a no-op),
        # and without this reset they would leak into the next iteration's
        # discriminator update.
        self.reset_gradients()

        # Pycrayon or nah
        if self.crayon:
            self.logger.add_scalar_value('pix2pix_gen_loss', gen_loss.data[0])
            self.logger.add_scalar_value('pix2pix_dis_loss', dis_loss.data[0])
        if idx % 50 == 0:
            tqdm.write('Epoch: {} [{}/{}]\t'
                       'D Loss: {:.4f}\t'
                       'G Loss: {:.4f}'.format(
                           c_epoch, idx, max_idx, dis_loss.data[0], gen_loss.data[0]
                       ))