def update_core(self):
    """Run one LSGAN training iteration: a generator step, then a discriminator step.

    Increments ``self._iter``, draws a latent batch and a real-image batch via
    ``self.get_latent_code_batch()`` / ``self.get_real_image_batch()``, and
    updates the optimizers registered as ``'gen'`` and ``'dis'``. Each loss is
    reported under the key ``'loss'`` with ``self.gen`` / ``self.dis`` as the
    observer. The LSGAN loss helpers are defined elsewhere in the file.
    """
    self._iter += 1
    opt_g = self.get_optimizer('gen')
    opt_d = self.get_optimizer('dis')

    data_z = self.get_latent_code_batch()
    data_x = self.get_real_image_batch()

    # --- Generator step ---------------------------------------------------
    x_fake = self.gen(Variable(data_z))
    dis_fake = self.dis(x_fake)
    # The generator is trained to make the discriminator score its samples
    # as real, hence the "real" LSGAN target applied to the fake scores.
    loss_gen = loss_func_lsgan_dis_real(dis_fake)
    chainer.report({'loss': loss_gen}, self.gen)
    # Optimizer.zero_grads() is deprecated since Chainer v1.15; clear the
    # gradients of the optimizer's target link directly instead.
    opt_g.target.cleargrads()
    loss_gen.backward()
    opt_g.update()

    # Cut the graph at x_fake so the discriminator loss below cannot
    # backpropagate into the generator. dis_fake is reused as-is: it was
    # computed before any update to the discriminator in this iteration.
    x_fake.unchain_backward()

    # --- Discriminator step -----------------------------------------------
    dis_real = self.dis(Variable(data_x))
    loss_dis = (loss_func_lsgan_dis_real(dis_real)
                + loss_func_lsgan_dis_fake(dis_fake))
    # Clearing here also discards the gradients that loss_gen.backward()
    # accumulated into the discriminator's parameters.
    opt_d.target.cleargrads()
    loss_dis.backward()
    opt_d.update()
    chainer.report({'loss': loss_dis}, self.dis)
评论列表
文章目录