trainer.py 文件源码

python
阅读 36 收藏 0 点赞 0 评论 0

项目: rarepepes  作者: kendricktan  项目源码  文件源码
def train(self, loader, c_epoch):
        """Run one pix2pix-style training epoch over ``loader``.

        For every batch the discriminator is updated first: the real pair
        ``(x, y)`` gets target 1 and the pair ``(x, G(x))``, with ``G(x)``
        detached, gets target 0.  Then the generator is updated with
        ``self.criterion_bce`` against target 1 plus ``self.lamb`` times
        ``self.criterion_l1(G(x), y)``.  Losses are optionally sent to
        Pycrayon (``self.crayon``) and printed with ``tqdm.write`` every
        50 batches.

        Args:
            loader: sized iterable of batches; ``features[0]`` is fed to the
                generator and ``features[1]`` is the paired target.
            c_epoch: current epoch number, used only in the progress message.
        """
        def _scalar(loss):
            # ``Tensor.item()`` reads a one-element loss on PyTorch >= 0.4,
            # where ``loss.data[0]`` raises on 0-dim tensors.  Fall back to
            # the legacy indexing for old Variable-based releases.
            try:
                return loss.item()
            except AttributeError:
                return loss.data[0]

        self.dis.train()
        self.gen.train()

        max_idx = len(loader)
        for idx, features in enumerate(tqdm(loader)):
            # Start every iteration from clean gradients.  Otherwise the
            # previous generator step's gradients leak into the next
            # discriminator update.
            self.reset_gradients()

            orig_x = Variable(self.cudafy(features[0]))
            orig_y = Variable(self.cudafy(features[1]))

            # ---- Discriminator ----
            # Real pairs -> target 1.
            dis_real = self.dis(torch.cat((orig_x, orig_y), 1))
            real_labels = Variable(self.cudafy(torch.ones(dis_real.size())))
            dis_real_loss = self.criterion_bce(dis_real, real_labels)

            # Fake pairs -> target 0.  Detach so this loss does not reach
            # the generator.
            gen_y = self.gen(orig_x)
            dis_fake = self.dis(torch.cat((orig_x, gen_y.detach()), 1))
            fake_labels = Variable(self.cudafy(torch.zeros(dis_fake.size())))
            dis_fake_loss = self.criterion_bce(dis_fake, fake_labels)

            dis_loss = dis_real_loss + dis_fake_loss
            dis_loss.backward()
            self.dis_optim.step()
            self.reset_gradients()

            # ---- Generator ----
            # Freeze the discriminator's parameters so the generator loss
            # does not write gradients into them.  Setting a ``volatile``
            # attribute on a module does not affect autograd.  The original
            # flags are restored even if the step raises.
            dis_flags = [(p, p.requires_grad) for p in self.dis.parameters()]
            for param, _ in dis_flags:
                param.requires_grad = False
            try:
                dis_gen = self.dis(torch.cat((orig_x, gen_y), 1))
                real_labels = Variable(self.cudafy(torch.ones(dis_gen.size())))
                gen_loss = self.criterion_bce(dis_gen, real_labels) + \
                    self.lamb * self.criterion_l1(gen_y, orig_y)
                gen_loss.backward()
            finally:
                for param, flag in dis_flags:
                    param.requires_grad = flag
            self.gen_optim.step()

            # Read each loss once, as a Python number.
            dis_loss_val = _scalar(dis_loss)
            gen_loss_val = _scalar(gen_loss)

            # Pycrayon or nah
            if self.crayon:
                self.logger.add_scalar_value('pix2pix_gen_loss', gen_loss_val)
                self.logger.add_scalar_value('pix2pix_dis_loss', dis_loss_val)

            if idx % 50 == 0:
                tqdm.write('Epoch: {} [{}/{}]\t'
                           'D Loss: {:.4f}\t'
                           'G Loss: {:.4f}'.format(
                               c_epoch, idx, max_idx, dis_loss_val, gen_loss_val
                           ))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号