def pre_train_d(self):
    """Pre-train the discriminator ``self.D``, or load pre-trained weights.

    If ``cfg.disc_pre_path`` is set, the state dict stored at that path is
    loaded into ``self.D`` and the method returns without training.

    Otherwise ``self.D`` is trained for ``cfg.d_pretrain`` steps. Each step
    draws one batch from ``self.real_loader`` (target class 0) and one batch
    from ``self.syn_train_loader`` that is first passed through ``self.R``
    (target class 1), then minimises the sum of the two
    ``self.local_adversarial_loss`` terms with ``self.opt_D``. Progress is
    printed every ``cfg.d_pre_per`` steps and on the last step, and the
    final weights are saved to ``models/D_pre.pkl``.

    Raises:
        AssertionError: if the real and synthetic batches differ in size.
    """
    print('=' * 50)
    if cfg.disc_pre_path:
        print('Loading D_pre from %s' % cfg.disc_pre_path)
        self.D.load_state_dict(torch.load(cfg.disc_pre_path))
        return

    # Report the number of steps the loop below actually runs.
    print('pre-training the discriminator network %d times...' % cfg.d_pretrain)
    self.D.train()
    self.R.eval()
    for index in range(cfg.d_pretrain):
        # A fresh iterator is created on every step, as in the original code.
        # NOTE(review): if the loaders do not shuffle, this yields the same
        # first batch each time -- confirm that is intended.
        real_image_batch, _ = next(iter(self.real_loader))
        real_image_batch = Variable(real_image_batch).cuda(cfg.cuda_num)
        syn_image_batch, _ = next(iter(self.syn_train_loader))
        syn_image_batch = Variable(syn_image_batch).cuda(cfg.cuda_num)
        assert real_image_batch.size(0) == syn_image_batch.size(0)

        # ---- real images: target class 0 ----
        d_real_pred = self.D(real_image_batch).view(-1, 2)
        d_real_y = Variable(
            torch.zeros(d_real_pred.size(0)).type(torch.LongTensor)
        ).cuda(cfg.cuda_num)
        acc_real = calc_acc(d_real_pred, 'real')
        d_loss_real = self.local_adversarial_loss(d_real_pred, d_real_y)

        # ---- refined synthetic images: target class 1 ----
        # Only D is optimised here, so cut the graph at the refiner output;
        # this avoids backpropagating into self.R during d_loss.backward().
        ref_image_batch = self.R(syn_image_batch).detach()
        d_ref_pred = self.D(ref_image_batch).view(-1, 2)
        d_ref_y = Variable(
            torch.ones(d_ref_pred.size(0)).type(torch.LongTensor)
        ).cuda(cfg.cuda_num)
        acc_ref = calc_acc(d_ref_pred, 'refine')
        d_loss_ref = self.local_adversarial_loss(d_ref_pred, d_ref_y)

        d_loss = d_loss_real + d_loss_ref
        self.opt_D.zero_grad()
        d_loss.backward()
        self.opt_D.step()

        if (index % cfg.d_pre_per == 0) or (index == cfg.d_pretrain - 1):
            # NOTE(review): `.data[0]` is the pre-0.4 PyTorch way to read a
            # scalar loss; newer versions need `d_loss.item()` -- confirm the
            # PyTorch version this project targets before changing it.
            print('[%d/%d] (D)d_loss:%f acc_real:%.2f%% acc_ref:%.2f%%'
                  % (index, cfg.d_pretrain, d_loss.data[0], acc_real, acc_ref))

    print('Save D_pre to models/D_pre.pkl')
    torch.save(self.D.state_dict(), 'models/D_pre.pkl')
# Comment list
# Table of contents