updater.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:chainer-gan-experiments 作者: Aixile 项目源码 文件源码
def update_core(self):
        """Run one training step: ``self._dis_iter`` critic updates, then one generator update.

        The critic is trained to minimise ``mean(D(fake) - D(real))``.
        Depending on ``self._mode`` this loss is regularised either with a
        gradient penalty (``'gp'``, weighted by ``self._lambda_gp``) or by
        clipping the critic's weights after every update (``'clip'``).

        The generator is trained to minimise ``-mean(D(G(z)))``.

        Reports ``loss`` and ``loss_w1`` for the critic (values from its last
        iteration) and ``loss`` for the generator.

        Assumes ``self._dis_iter >= 1``; otherwise ``loss_dis``/``w1`` are
        never bound and the critic report below raises ``NameError``.
        """
        xp = self.gen.xp
        self._iter += 1

        # ---- critic updates ----
        opt_d = self.get_optimizer('dis')
        for i in range(self._dis_iter):
            d_fake = self.get_fake_image_batch()
            d_real = self.get_real_image_batch()

            y_fake = self.dis(Variable(d_fake), test=False)
            y_real = self.dis(Variable(d_real), test=False)

            w1 = F.average(y_fake - y_real)

            loss_dis = w1

            if self._mode == 'gp':
                # WGAN-GP: draw one interpolation coefficient per example
                # (the original drew a single scalar for the whole batch),
                # shaped (N, 1, ..., 1) so it broadcasts over the other axes.
                eta_shape = (d_real.shape[0],) + (1,) * (d_real.ndim - 1)
                eta = xp.random.uniform(0.0, 1.0, size=eta_shape).astype('f')
                c = (d_real * eta + (1.0 - eta) * d_fake).astype('f')
                # NOTE(review): retain_forward=True presumably keeps the
                # intermediate activations needed by differentiable_backward
                # below; confirm against the discriminator implementation.
                y = self.dis(Variable(c), test=False, retain_forward=True)

                # Presumably returns d(critic output)/dc as a differentiable
                # Variable, seeded with an upstream gradient of ones. TODO confirm.
                g = xp.ones_like(y.data)
                grad_c = self.dis.differentiable_backward(Variable(g))
                # Per-example L2 norm of that gradient; axis=(1, 2, 3) assumes
                # 4-D batches with the example axis first.
                grad_c_l2 = F.sqrt(F.sum(grad_c**2, axis=(1, 2, 3)))

                # Penalise deviation of the gradient norm from 1
                # (loss_l2 is presumably a squared-error helper; verify).
                loss_gp = loss_l2(grad_c_l2, 1.0)

                loss_dis += self._lambda_gp * loss_gp

            opt_d.zero_grads()
            loss_dis.backward()
            opt_d.update()

            if self._mode == 'clip':
                # Original WGAN: enforce the Lipschitz constraint by clipping.
                self.dis.clip()

        chainer.report({'loss': loss_dis, 'loss_w1': w1}, self.dis)

        # ---- generator update ----
        z_in = self.get_latent_code_batch()
        x_out = self.gen(Variable(z_in), test=False)

        opt_g = self.get_optimizer('gen')
        y_fake = self.dis(x_out, test=False)
        # The generator maximises the critic's score on its own samples.
        loss_gen = - F.average(y_fake)

        chainer.report({'loss': loss_gen}, self.gen)

        opt_g.zero_grads()
        loss_gen.backward()
        opt_g.update()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号