def pre_train_r(self):
    """Pre-train the refiner network R with the self-regularization loss only.

    If ``cfg.ref_pre_path`` points to a saved state dict, load it and return
    immediately.  Otherwise run ``cfg.r_pretrain`` optimization steps; every
    ``cfg.r_pre_per`` steps (and on the final step) log the loss and save a
    sample image grid to ``cfg.train_res_path``.  The pre-trained weights are
    finally saved to ``models/R_pre.pkl``.
    """
    print('=' * 50)
    if cfg.ref_pre_path:
        print('Loading R_pre from %s' % cfg.ref_pre_path)
        self.R.load_state_dict(torch.load(cfg.ref_pre_path))
        return

    # Following the SimGAN recipe: the refiner is first trained with just
    # the self-regularization loss for a fixed number of steps.
    print('pre-training the refiner network %d times...' % cfg.r_pretrain)

    def _next_batch(iterator, loader):
        # Advance `iterator`, transparently restarting it from `loader`
        # when the underlying dataset is exhausted.  Returns (batch, iterator).
        try:
            batch, _ = next(iterator)
        except StopIteration:
            iterator = iter(loader)
            batch, _ = next(iterator)
        return batch, iterator

    # BUGFIX: the original code called `loader.__iter__().next()` inside the
    # loop, which builds a brand-new DataLoader iterator every step and
    # therefore always yields the *first* batch — the refiner was being
    # pre-trained on one repeated batch (and workers were respawned each step).
    # Create each iterator once and reuse it.
    syn_iter = iter(self.syn_train_loader)
    real_iter = iter(self.real_loader)

    for index in range(cfg.r_pretrain):
        syn_image_batch, syn_iter = _next_batch(syn_iter, self.syn_train_loader)
        syn_image_batch = Variable(syn_image_batch).cuda(cfg.cuda_num)

        self.R.train()
        ref_image_batch = self.R(syn_image_batch)

        r_loss = self.self_regularization_loss(ref_image_batch, syn_image_batch)
        # Scale the regularization term by its weight (lambda/delta).
        r_loss = torch.mul(r_loss, self.delta)

        self.opt_R.zero_grad()
        r_loss.backward()
        self.opt_R.step()

        # Log (and dump a sample image grid) every `cfg.r_pre_per` steps,
        # and always on the last step.
        if (index % cfg.r_pre_per == 0) or (index == cfg.r_pretrain - 1):
            # NOTE: `.data[0]` / `volatile=True` are pre-0.4 PyTorch idioms,
            # kept intentionally to match the torch version this file targets.
            print('[%d/%d] (R)reg_loss: %.4f' % (index, cfg.r_pretrain, r_loss.data[0]))

            syn_image_batch, syn_iter = _next_batch(syn_iter, self.syn_train_loader)
            syn_image_batch = Variable(syn_image_batch, volatile=True).cuda(cfg.cuda_num)

            real_image_batch, real_iter = _next_batch(real_iter, self.real_loader)
            real_image_batch = Variable(real_image_batch, volatile=True)

            self.R.eval()
            ref_image_batch = self.R(syn_image_batch)

            figure_path = os.path.join(cfg.train_res_path, 'refined_image_batch_pre_train_%d.png' % index)
            generate_img_batch(syn_image_batch.data.cpu(), ref_image_batch.data.cpu(),
                               real_image_batch.data, figure_path)
            self.R.train()

    print('Save R_pre to models/R_pre.pkl')
    torch.save(self.R.state_dict(), 'models/R_pre.pkl')
# (removed stray blog-page text accidentally pasted into the source:
#  "评论列表" = "comment list", "文章目录" = "article table of contents" —
#  it was not Python and broke parsing)