def update_core(self):
    """Run one training iteration for the generator/discriminator pair.

    Pulls the next batch from the ``'main'`` iterator, evaluates the
    discriminator on that batch and on samples produced by the generator,
    then calls ``update`` on the ``'dis'`` optimizer followed by the
    ``'gen'`` optimizer.

    Reads ``self.gen``, ``self.dis``, ``self.converter``, ``self.device``,
    ``self.loss_dis`` and ``self.loss_gen``. Returns nothing.
    """
    gen_optimizer = self.get_optimizer('gen')
    dis_optimizer = self.get_optimizer('dis')

    batch = self.get_iterator('main').next()
    # NOTE(review): the division by 255 presumably rescales 8-bit pixel
    # values into [0, 1] -- confirm against the dataset being iterated.
    x_real = Variable(self.converter(batch, self.device)) / 255.
    # Array module matching x_real's underlying data; used below so the
    # hidden vector is built with that same module.
    xp = chainer.cuda.get_array_module(x_real.data)

    gen, dis = self.gen, self.dis
    batchsize = len(batch)

    # Discriminator output on the real batch.
    y_real = dis(x_real)

    # One hidden vector per real sample, turned into generated samples and
    # scored by the same discriminator.
    z = Variable(xp.asarray(gen.make_hidden(batchsize)))
    x_fake = gen(z)
    y_fake = dis(x_fake)

    # y_fake is computed once and handed to both update calls; the
    # discriminator update is issued before the generator update.
    dis_optimizer.update(self.loss_dis, dis, y_fake, y_real)
    gen_optimizer.update(self.loss_gen, gen, y_fake)
评论列表
文章目录