optimize_gan.py source file

python

Project: GRAN  Author: jiwoongim
# Module-level imports required by this method
import theano
import theano.tensor as T


def optimize_gan(self, model, train_set, valid_set, test_set, lam1=0.00001):
        """
        Build the optimizer for a non-packaged dataset.

        Returns the compiled update functions for the discriminator and
        generator, plus functions that compute the validation and test costs.
        """

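        # Symbolic inputs: minibatch index, learning rate, and data batch
        i = T.iscalar('i')
        lr = T.fscalar('lr')
        Xu = T.matrix('X')

        # Discriminator cost with L2 weight decay scaled by lam1
        cost_disc = (model.cost_dis(Xu, self.batch_sz)
                     + lam1 * model.dis_network.weight_decay_l2())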

        gparams_dis = T.grad(cost_disc, model.dis_network.params)

        cost_gen    = model.cost_gen(self.batch_sz)
        gparams_gen = T.grad(cost_gen, model.gen_network.params)


        updates_dis = self.ADAM(model.dis_network.params, gparams_dis, lr)
        updates_gen = self.ADAM(model.gen_network.params, gparams_gen, lr)

        # theano.Param is the older Theano API; it was deprecated in later
        # releases in favor of theano.In(lr, value=...).
        discriminator_update = theano.function(
                [i, theano.Param(lr, default=self.epsilon_dis)],
                outputs=cost_disc, updates=updates_dis,
                givens={Xu: train_set[0][i*self.batch_sz:(i+1)*self.batch_sz]})

        generator_update = theano.function(
                [theano.Param(lr, default=self.epsilon_gen)],
                outputs=cost_gen, updates=updates_gen)

        get_valid_cost = theano.function(
                [i], outputs=[cost_disc, cost_gen],
                givens={Xu: valid_set[0][i*self.batch_sz:(i+1)*self.batch_sz]})

        get_test_cost = theano.function(
                [i], outputs=[cost_disc, cost_gen],
                givens={Xu: test_set[0][i*self.batch_sz:(i+1)*self.batch_sz]})

        return discriminator_update, generator_update, get_valid_cost, get_test_cost
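
The four returned functions are meant to be called from a training loop that the snippet does not show. The following is a minimal sketch of one way such a loop might look, not the GRAN project's actual driver. The name `run_epoch`, the batch counts, and the `k_dis` parameter are hypothetical and introduced only for illustration. Plain Python callables stand in for the compiled Theano functions so the sketch runs without Theano installed.

```python
def run_epoch(discriminator_update, generator_update, get_valid_cost,
              n_train_batches, n_valid_batches, k_dis=1):
    """Alternate discriminator and generator updates over one epoch.

    k_dis is a hypothetical knob (not in the source): the number of
    discriminator steps taken per generator step.
    """
    dis_costs, gen_costs = [], []
    for i in range(n_train_batches):
        # The discriminator function takes the minibatch index i; in the
        # source, the matching slice of train_set[0] is supplied via givens.
        dis_costs.append(float(discriminator_update(i)))
        if (i + 1) % k_dis == 0:
            # The generator function takes no index; in the source its cost
            # is built from batch_sz alone, with no data batch.
            gen_costs.append(float(generator_update()))

    # Each validation call returns [cost_disc, cost_gen] for batch i.
    valid = [get_valid_cost(i) for i in range(n_valid_batches)]
    mean_valid_dis = sum(v[0] for v in valid) / len(valid)
    mean_valid_gen = sum(v[1] for v in valid) / len(valid)
    return dis_costs, gen_costs, (mean_valid_dis, mean_valid_gen)


if __name__ == "__main__":
    # Stand-in callables with made-up cost values, for demonstration only.
    dis, gen, valid = run_epoch(
        discriminator_update=lambda i: 1.0 / (i + 1),
        generator_update=lambda: 0.5,
        get_valid_cost=lambda i: [2.0, 4.0],
        n_train_batches=4, n_valid_batches=2, k_dis=2)
    print(dis, gen, valid)
```

With the real functions, both update calls also accept an optional learning rate, since `lr` is declared with a default (`self.epsilon_dis` or `self.epsilon_gen`) in the source.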