main.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目: SimGAN_pytorch　作者: AlexHex7　(项目源码 | 文件源码)
def pre_train_r(self):
        """Pre-train the refiner ``self.R`` using only the self-regularization loss.

        If ``cfg.ref_pre_path`` is set, the refiner weights are loaded from that
        file with ``torch.load`` and the method returns without training.

        Otherwise ``self.R`` is optimised with ``self.opt_R`` for
        ``cfg.r_pretrain`` steps. Each step feeds one synthetic batch (moved to
        GPU ``cfg.cuda_num``) through the refiner. The loss is
        ``self.self_regularization_loss(refined, synthetic)`` multiplied by
        ``self.delta``.

        Every ``cfg.r_pre_per`` steps, and on the final step, the method:
        prints the loss; writes a synthetic / refined / real comparison figure
        to ``cfg.train_res_path`` via ``generate_img_batch``; and saves the
        refiner's ``state_dict`` to ``models/R_pre.pkl``.
        """
        print('=' * 50)
        if cfg.ref_pre_path:
            print('Loading R_pre from %s' % cfg.ref_pre_path)
            self.R.load_state_dict(torch.load(cfg.ref_pre_path))
            return

        # First train the refiner R_theta with just the self-regularization loss.
        print('pre-training the refiner network %d times...' % cfg.r_pretrain)

        for index in range(cfg.r_pretrain):
            # `next(iter(...))` replaces the Python-2-only `.__iter__().next()`,
            # which fails on DataLoader iterators that only define `__next__`.
            # NOTE(review): a fresh iterator is created every step, so each step
            # takes the first batch of a new pass over the loader. That batch only
            # varies between steps if the loader shuffles — confirm its config.
            syn_image_batch, _ = next(iter(self.syn_train_loader))
            syn_image_batch = Variable(syn_image_batch).cuda(cfg.cuda_num)

            self.R.train()
            ref_image_batch = self.R(syn_image_batch)

            r_loss = self.self_regularization_loss(ref_image_batch, syn_image_batch)
            r_loss = torch.mul(r_loss, self.delta)

            self.opt_R.zero_grad()
            r_loss.backward()
            self.opt_R.step()

            # Log/snapshot every `cfg.r_pre_per` steps and always on the last step.
            if (index % cfg.r_pre_per == 0) or (index == cfg.r_pretrain - 1):
                # `.data[0]` raises IndexError on the 0-dim loss tensors of newer
                # PyTorch. Flattening first works for both 1-element and 0-dim losses.
                loss_value = float(r_loss.data.view(-1)[0])
                print('[%d/%d] (R)reg_loss: %.4f' % (index, cfg.r_pretrain, loss_value))

                # Fresh (no-grad) batches used only for the visual comparison figure.
                syn_image_batch, _ = next(iter(self.syn_train_loader))
                syn_image_batch = Variable(syn_image_batch, volatile=True).cuda(cfg.cuda_num)

                real_image_batch, _ = next(iter(self.real_loader))
                real_image_batch = Variable(real_image_batch, volatile=True)

                self.R.eval()
                ref_image_batch = self.R(syn_image_batch)

                figure_path = os.path.join(cfg.train_res_path, 'refined_image_batch_pre_train_%d.png' % index)
                generate_img_batch(syn_image_batch.data.cpu(), ref_image_batch.data.cpu(),
                                   real_image_batch.data, figure_path)
                # Restore training mode before the next optimisation step.
                self.R.train()

                print('Save R_pre to models/R_pre.pkl')
                torch.save(self.R.state_dict(), 'models/R_pre.pkl')
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号