updater.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:chainer-gan-experiments 作者: Aixile 项目源码 文件源码
def update_core(self):
        xp = self.gen.xp
        self._iter += 1

        opt_g = self.get_optimizer('gen')
        opt_d = self.get_optimizer('dis')

        data_z = self.get_latent_code_batch()
        data_x = self.get_real_image_batch()

        x_fake = self.gen(Variable(data_z))
        dis_fake = self.dis(x_fake)

        loss_gen = loss_func_lsgan_dis_real(dis_fake)
        chainer.report({'loss': loss_gen}, self.gen)

        opt_g.zero_grads()
        loss_gen.backward()
        opt_g.update()

        x_fake.unchain_backward()
        x_real = Variable(data_x)
        dis_real = self.dis(x_real)
        loss_dis = loss_func_lsgan_dis_real(dis_real) + loss_func_lsgan_dis_fake(dis_fake)

        opt_d.zero_grads()
        loss_dis.backward()
        opt_d.update()

        chainer.report({'loss': loss_dis}, self.dis)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号