def optimize_gan_hkl(self, model, lam1=0.00001):
    """Build the optimizer functions for an hkl-packaged dataset.

    Compiles Theano functions that update the discriminator and the
    generator of ``model``, plus a function that only evaluates both costs.

    Args:
        model: GAN model exposing ``cost_dis(X, batch_sz)``,
            ``cost_gen(batch_sz)``, ``dis_network`` (with ``params`` and
            ``weight_decay_l2()``) and ``gen_network`` (with ``params``).
        lam1: Coefficient of the ``weight_decay_l2()`` penalty that is added
            to the discriminator cost only.

    Returns:
        A 4-tuple of compiled Theano functions:

        * ``discriminator_update(X, lr=self.epsilon_dis)``: applies the
          ``self.ADAM`` updates to the discriminator parameters and returns
          the (regularized) discriminator cost.
        * ``generator_update(lr=self.epsilon_gen)``: applies the
          ``self.ADAM`` updates to the generator parameters and returns the
          generator cost.
        * ``get_valid_cost(X)``: returns ``[cost_disc, cost_gen]`` without
          applying any updates.
        * ``get_test_cost(X)``: the same compiled function as
          ``get_valid_cost``.
    """
    lr = T.fscalar('lr')
    Xu = T.fmatrix('X')

    # Discriminator objective: model cost plus weight decay on the
    # discriminator network only (the generator cost is not regularized).
    cost_disc = (model.cost_dis(Xu, self.batch_sz)
                 + lam1 * model.dis_network.weight_decay_l2())
    gparams_dis = T.grad(cost_disc, model.dis_network.params)

    # The generator cost takes no data input, only the batch size.
    cost_gen = model.cost_gen(self.batch_sz)
    gparams_gen = T.grad(cost_gen, model.gen_network.params)

    # Both networks share one symbolic learning-rate input.
    updates_dis = self.ADAM(model.dis_network.params, gparams_dis, lr)
    updates_gen = self.ADAM(model.gen_network.params, gparams_gen, lr)

    # lr is an optional input whose default is read from
    # self.epsilon_dis / self.epsilon_gen when this method runs.
    # NOTE(review): newer Theano releases deprecate theano.Param in favour of
    # theano.In(lr, value=...); kept as-is for the Theano version this file
    # targets.
    discriminator_update = theano.function(
        [Xu, theano.Param(lr, default=self.epsilon_dis)],
        outputs=cost_disc, updates=updates_dis)
    generator_update = theano.function(
        [theano.Param(lr, default=self.epsilon_gen)],
        outputs=cost_gen, updates=updates_gen)

    # Validation and test evaluation use an identical graph with no updates,
    # so compile it once and return the same function for both roles.
    get_eval_cost = theano.function([Xu], outputs=[cost_disc, cost_gen])
    get_valid_cost = get_eval_cost
    get_test_cost = get_eval_cost
    return discriminator_update, generator_update, get_valid_cost, get_test_cost
评论列表
文章目录