def calc_gradient_penalty(D, real_data, fake_data, *, penalty_lambda=None):
    """Calculate the gradient penalty for WGAN-GP.

    Samples random points on the straight lines between paired real and fake
    samples, evaluates the critic ``D`` at those points, and penalizes the
    deviation of each sample's input-gradient L2 norm from 1.

    Args:
        D: Critic callable. It is called on a batch shaped like ``real_data``;
            its output is differentiated with an all-ones ``grad_outputs``.
        real_data: Batch of real samples. Dim 0 is treated as the batch dim;
            any number of trailing dims is supported.
        fake_data: Batch of generated samples, same shape as ``real_data``.
        penalty_lambda: Weight of the penalty. When ``None`` (the default),
            ``params.penalty_lambda`` from the project config is used.

    Returns:
        Scalar tensor ``penalty_lambda * mean((||grad||_2 - 1) ** 2)``. It is
        built with ``create_graph=True`` so it can be backpropagated into the
        critic's parameters.
    """
    if penalty_lambda is None:
        penalty_lambda = params.penalty_lambda

    batch_size = real_data.size(0)
    # One mixing coefficient per sample, shaped (B, 1, ..., 1) so it
    # broadcasts over every non-batch dim (2-D features or N-D images).
    # The actual batch size is used, not a configured one, so a smaller
    # final batch still works; device and dtype follow the input.
    alpha = torch.rand(
        (batch_size,) + (1,) * (real_data.dim() - 1),
        device=real_data.device,
        dtype=real_data.dtype,
    )
    # detach() yields a fresh leaf tensor, so requires_grad can be enabled
    # and gradients are taken w.r.t. the interpolates only, even when the
    # inputs carry autograd history (e.g. fake_data straight from G).
    interpolates = (alpha * real_data + (1 - alpha) * fake_data).detach()
    interpolates.requires_grad_(True)

    disc_interpolates = D(interpolates)
    gradients = torch.autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,  # penalty must be differentiable w.r.t. D's params
        retain_graph=True,
    )[0]

    # Flatten per sample so the norm spans all non-batch dims; for 2-D input
    # this equals gradients.norm(2, dim=1).
    grad_norm = gradients.reshape(batch_size, -1).norm(2, dim=1)
    return penalty_lambda * ((grad_norm - 1) ** 2).mean()
# Web-page residue from the source article, not code:
# "评论列表" (comment list), "文章目录" (table of contents).