optimize_gan.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:GRAN 作者: jiwoongim 项目源码 文件源码
def optimize_gan_hkl(self, model, lam1=0.00001):
        """
        Build the Theano training/evaluation functions for an hkl packaged dataset.

        The discriminator cost is ``model.cost_dis(X, self.batch_sz)`` plus an
        L2 weight-decay penalty on the discriminator network scaled by ``lam1``.
        The generator cost is ``model.cost_gen(self.batch_sz)``.  Parameter
        updates for both networks are produced by ``self.ADAM``.

        Args:
            model: object exposing ``cost_dis``, ``cost_gen``, ``dis_network``
                and ``gen_network`` (each network providing ``params``; the
                discriminator also ``weight_decay_l2()``).
            lam1: coefficient of the discriminator L2 weight-decay term.

        Returns:
            A 4-tuple of compiled Theano functions:

            * ``discriminator_update(X, lr=self.epsilon_dis)`` -- applies the
              discriminator updates and returns the discriminator cost.
            * ``generator_update(lr=self.epsilon_gen)`` -- applies the
              generator updates and returns the generator cost.
            * ``get_valid_cost(X)`` -- returns ``[cost_disc, cost_gen]``
              without updating any parameter.
            * ``get_test_cost(X)`` -- same outputs as ``get_valid_cost``.
        """

        # Symbolic inputs: learning rate and a float32 data matrix.
        lr = T.fscalar('lr')
        Xu = T.fmatrix('X')

        # Discriminator objective, including the weight-decay penalty.
        cost_disc = (model.cost_dis(Xu, self.batch_sz)
                     + lam1 * model.dis_network.weight_decay_l2())
        gparams_dis = T.grad(cost_disc, model.dis_network.params)

        # Generator objective; it takes no data input, only the batch size.
        cost_gen = model.cost_gen(self.batch_sz)
        gparams_gen = T.grad(cost_gen, model.gen_network.params)

        updates_dis = self.ADAM(model.dis_network.params, gparams_dis, lr)
        updates_gen = self.ADAM(model.gen_network.params, gparams_gen, lr)

        # theano.In(..., value=...) replaces the deprecated
        # theano.Param(..., default=...): it declares `lr` as an optional
        # input whose default is the configured learning rate.
        discriminator_update = theano.function(
            [Xu, theano.In(lr, value=self.epsilon_dis)],
            outputs=cost_disc, updates=updates_dis)

        generator_update = theano.function(
            [theano.In(lr, value=self.epsilon_gen)],
            outputs=cost_gen, updates=updates_gen)

        # Evaluation functions: no `updates`, so parameters are left untouched.
        # Note the reported discriminator cost still includes the L2 penalty.
        get_valid_cost = theano.function([Xu], outputs=[cost_disc, cost_gen])

        get_test_cost = theano.function([Xu], outputs=[cost_disc, cost_gen])

        return discriminator_update, generator_update, get_valid_cost, get_test_cost
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号