main.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:SimGAN_pytorch 作者: AlexHex7 项目源码 文件源码
def pre_train_d(self):
        """Pre-train the discriminator D while the refiner R is held fixed.

        If ``cfg.disc_pre_path`` is set, D's state dict is loaded from that
        file and the method returns without training.

        Otherwise D is trained for ``cfg.d_pretrain`` steps. Each step uses:
          * one mini-batch of real images, labelled class 0;
          * one mini-batch of synthetic images passed through R ("refined"),
            labelled class 1.
        The two ``self.local_adversarial_loss`` terms are summed and only
        ``self.opt_D`` is stepped, so R is not updated. Progress is printed
        every ``cfg.d_pre_per`` steps and on the final step. Afterwards D's
        state dict is saved to ``models/D_pre.pkl``.
        """
        print('=' * 50)
        if cfg.disc_pre_path:
            print('Loading D_pre from %s' % cfg.disc_pre_path)
            self.D.load_state_dict(torch.load(cfg.disc_pre_path))
            return

        print('pre-training the discriminator network %d times...' % cfg.d_pretrain)

        self.D.train()
        self.R.eval()
        for index in range(cfg.d_pretrain):
            # A fresh iterator is built on every step, so each draw is the first
            # batch of a new pass over the loader. That is a random batch only if
            # the loader shuffles (NOTE(review): confirm shuffle=True). Building
            # an iterator per step may also be costly if the loader uses workers.
            real_image_batch, _ = next(iter(self.real_loader))
            real_image_batch = Variable(real_image_batch).cuda(cfg.cuda_num)

            syn_image_batch, _ = next(iter(self.syn_train_loader))
            syn_image_batch = Variable(syn_image_batch).cuda(cfg.cuda_num)

            assert real_image_batch.size(0) == syn_image_batch.size(0)

            # ============ real images: target class 0 =====================
            # view(-1, 2) flattens D's output into rows of two class scores;
            # presumably one row per local patch (SimGAN's local adversarial
            # loss) -- confirm against D's output shape.
            d_real_pred = self.D(real_image_batch).view(-1, 2)
            d_real_y = Variable(torch.zeros(d_real_pred.size(0)).type(torch.LongTensor)).cuda(cfg.cuda_num)

            acc_real = calc_acc(d_real_pred, 'real')
            d_loss_real = self.local_adversarial_loss(d_real_pred, d_real_y)

            # ============ refined images: target class 1 ==================
            # Detach R's output so d_loss.backward() stops at the refined images.
            # Only opt_D steps here, so gradients flowing into R would just
            # accumulate on R's parameters and waste compute.
            ref_image_batch = self.R(syn_image_batch).detach()
            d_ref_pred = self.D(ref_image_batch).view(-1, 2)
            # Size the labels from this prediction rather than from d_real_pred.
            d_ref_y = Variable(torch.ones(d_ref_pred.size(0)).type(torch.LongTensor)).cuda(cfg.cuda_num)

            acc_ref = calc_acc(d_ref_pred, 'refine')
            d_loss_ref = self.local_adversarial_loss(d_ref_pred, d_ref_y)

            d_loss = d_loss_real + d_loss_ref
            self.opt_D.zero_grad()
            d_loss.backward()
            self.opt_D.step()

            if (index % cfg.d_pre_per == 0) or (index == cfg.d_pretrain - 1):
                # PyTorch >= 0.4 returns 0-dim losses, on which .data[0] raises
                # IndexError. Use .item() when available; older versions lack it
                # and fall back to .data[0].
                loss_value = d_loss.item() if hasattr(d_loss, 'item') else d_loss.data[0]
                print('[%d/%d] (D)d_loss:%f  acc_real:%.2f%% acc_ref:%.2f%%'
                      % (index, cfg.d_pretrain, loss_value, acc_real, acc_ref))

        save_path = 'models/D_pre.pkl'
        print('Save D_pre to %s' % save_path)
        torch.save(self.D.state_dict(), save_path)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号