def update_core(self):
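    """One training iteration: update the generator, then the discriminator
    (adversarial + tag-classification losses, with a DRAGAN-style gradient
    penalty on perturbed real samples)."""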
    xp = self.gen.xp
    self._iter += 1
    opt_g = self.get_optimizer('gen')
    opt_d = self.get_optimizer('dis')

    # Sample a latent batch and a random tag batch for the generator,
    # plus a batch of real images with their ground-truth tags.
    data_z = self.get_latent_code_batch()
    data_tag = self.get_fake_tag_batch()
    data_x, data_real_tag = self.get_real_image_batch()
    # --- Generator update ---
    # Generate images conditioned on the sampled tags and score them with
    # the discriminator (adversarial logit + tag-classifier logits).
    x_fake = self.gen(F.concat([Variable(data_z), Variable(data_tag)]))
    dis_fake, dis_g_class = self.dis(x_fake)

    # Clamp the tag targets to {0, 1} before computing the classification loss.
    data_tag[data_tag < 0] = 0.0
    loss_g_class = loss_sigmoid_cross_entropy_with_logits(dis_g_class, data_tag)
    # Adversarial term (fakes should be judged real) plus the classification term.
    loss_gen = self._lambda_adv * loss_func_dcgan_dis_real(dis_fake) + loss_g_class
    chainer.report({'loss': loss_gen, 'loss_c': loss_g_class}, self.gen)

    opt_g.zero_grads()
    loss_gen.backward()
    opt_g.update()

    # Cut the graph below x_fake so the discriminator update does not
    # backpropagate into the generator; dis_fake itself stays usable below.
    x_fake.unchain_backward()
    # --- Discriminator update ---
    # DRAGAN-style perturbation: add uniform noise to the real batch, scaled
    # by the per-element standard deviation across the batch.
    std_data_x = xp.std(data_x, axis=0, keepdims=True)
    rnd_x = xp.random.uniform(0, 1, data_x.shape).astype("f")
    x_perturbed = Variable(data_x + 0.5 * rnd_x * std_data_x)

    x_real = Variable(data_x)
    dis_real, dis_d_class = self.dis(x_real)
    dis_perturbed, _ = self.dis(x_perturbed, retain_forward=True)

    # Backpropagate a vector of ones through the discriminator to obtain the
    # gradient w.r.t. x_perturbed as a differentiable graph, then penalize
    # gradient norms that deviate from 1.
    g = Variable(xp.ones_like(dis_perturbed.data))
    grad = self.dis.differentiable_backward(g)
    grad_l2 = F.sqrt(F.sum(grad ** 2, axis=(1, 2, 3)))
    loss_gp = self._lambda_gp * loss_l2(grad_l2, 1.0)
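    # Note: differentiable_backward is a custom method of this repository's
    # discriminator that replays the backward pass as a differentiable graph.
    # In Chainer v3+, the same penalty could roughly be computed with the
    # built-in double-backprop API instead (a sketch, not this repo's code):
    #
    #   grad, = chainer.grad([dis_perturbed], [x_perturbed],
    #                        enable_double_backprop=True)
    #   grad_l2 = F.sqrt(F.sum(grad ** 2, axis=(1, 2, 3)))
    #   loss_gp = self._lambda_gp * loss_l2(grad_l2, 1.0)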
    # Classification loss on real images, against their ground-truth tags.
    loss_d_class = loss_sigmoid_cross_entropy_with_logits(dis_d_class, data_real_tag)

    # Total discriminator loss: adversarial terms (dis_fake is reused from the
    # generator pass; only the graph below x_fake was cut), tag classification,
    # and the gradient penalty.
    loss_dis = (self._lambda_adv * (loss_func_dcgan_dis_real(dis_real) +
                                    loss_func_dcgan_dis_fake(dis_fake))
                + loss_d_class
                + loss_gp)
    opt_d.zero_grads()
    loss_dis.backward()
    opt_d.update()
    chainer.report({'loss': loss_dis, 'loss_gp': loss_gp, 'loss_c': loss_d_class}, self.dis)
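
# For reference: the loss helpers called above are defined elsewhere in this
# repository. The definitions below are minimal sketches of the standard
# formulations they appear to implement (softplus-based DCGAN losses,
# multi-label sigmoid cross-entropy with logits, and an L2 penalty), not the
# repository's exact code; they assume `import chainer.functions as F` as
# used in the rest of the file.
def loss_func_dcgan_dis_real(h):
    # -log(sigmoid(h)) = softplus(-h): encourage h to be scored "real".
    return F.sum(F.softplus(-h)) / h.data.size

def loss_func_dcgan_dis_fake(h):
    # -log(1 - sigmoid(h)) = softplus(h): encourage h to be scored "fake".
    return F.sum(F.softplus(h)) / h.data.size

def loss_sigmoid_cross_entropy_with_logits(x, t):
    # Elementwise sigmoid cross-entropy with {0, 1} targets t:
    # -[t*log(sigmoid(x)) + (1-t)*log(1-sigmoid(x))] = softplus(x) - t*x.
    return F.sum(F.softplus(x) - x * t) / x.data.size

def loss_l2(x, t):
    # Mean squared deviation of x from target t (t may be a scalar).
    return F.sum((x - t) ** 2) / x.data.size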