def train(self, loader, c_epoch):
    """Run one epoch of pix2pix adversarial training.

    Alternates a discriminator update (real pair vs. detached fake pair)
    with a generator update (adversarial BCE + lambda-weighted L1
    reconstruction) for every batch in ``loader``.

    Args:
        loader: iterable of ``(input, target)`` tensor pairs; must support
            ``len()`` (used only for progress reporting).
        c_epoch: current epoch number, used only in the log line.
    """
    self.dis.train()
    self.gen.train()
    self.reset_gradients()
    max_idx = len(loader)
    for idx, features in enumerate(tqdm(loader)):
        orig_x = Variable(self.cudafy(features[0]))
        orig_y = Variable(self.cudafy(features[1]))

        """ Discriminator """
        # NOTE(review): the original set `self.dis.volatile` here; that was a
        # no-op (`volatile` was a Variable flag in PyTorch <= 0.3, not a
        # Module attribute), so it has been dropped. D is protected from the
        # generator pass by the gradient resets below, not by that flag.

        # Train with a real pair: D should predict 1 for (x, real_y).
        dis_real = self.dis(torch.cat((orig_x, orig_y), 1))
        real_labels = Variable(self.cudafy(
            torch.ones(dis_real.size())
        ))
        dis_real_loss = self.criterion_bce(
            dis_real, real_labels)

        # Train with a fake pair: detach gen_y so no gradient flows into G
        # during the discriminator update.
        gen_y = self.gen(orig_x)
        dis_fake = self.dis(torch.cat((orig_x, gen_y.detach()), 1))
        fake_labels = Variable(self.cudafy(
            torch.zeros(dis_fake.size())
        ))
        dis_fake_loss = self.criterion_bce(
            dis_fake, fake_labels)

        # Update discriminator weights, then clear gradients so the
        # generator step starts from zeroed buffers.
        dis_loss = dis_real_loss + dis_fake_loss
        dis_loss.backward()
        self.dis_optim.step()
        self.reset_gradients()

        """ Generator """
        # Adversarial term: G wants D to predict "real" (1) on the fake
        # pair, plus an L1 reconstruction term weighted by self.lamb.
        dis_on_fake = self.dis(torch.cat((orig_x, gen_y), 1))
        real_labels = Variable(self.cudafy(
            torch.ones(dis_on_fake.size())
        ))
        gen_loss = self.criterion_bce(dis_on_fake, real_labels) + \
            self.lamb * self.criterion_l1(gen_y, orig_y)
        gen_loss.backward()
        self.gen_optim.step()
        # BUGFIX: clear gradients after the generator step as well.
        # gen_loss.backward() also populates D's .grad buffers; without this
        # reset those gradients would accumulate into the NEXT iteration's
        # dis_loss.backward() and corrupt the discriminator update.
        self.reset_gradients()

        # Pycrayon or nah
        if self.crayon:
            self.logger.add_scalar_value('pix2pix_gen_loss', gen_loss.data[0])
            self.logger.add_scalar_value('pix2pix_dis_loss', dis_loss.data[0])
        if idx % 50 == 0:
            tqdm.write('Epoch: {} [{}/{}]\t'
                       'D Loss: {:.4f}\t'
                       'G Loss: {:.4f}'.format(
                           c_epoch, idx, max_idx, dis_loss.data[0], gen_loss.data[0]
                       ))
# (removed stray scraped page text: "评论列表" / "文章目录" — blog footer, not code)