def backward(self, grad_output):
    # d/dz log Phi(z) = phi(z) / Phi(z), evaluated in log space for stability.
    z, log_phi_z = self.saved_tensors
    log_phi_z_grad = torch.zeros_like(z)

    # For z < -1, Phi(z) underflows, so use the continued-fraction
    # numerator/denominator stashed on `self` by the forward pass.
    z_is_small = z.lt(-1)
    z_is_not_small = ~z_is_small
    if z_is_small.sum() > 0:
        log_phi_z_grad[z_is_small] = torch.abs(self.denominator.div(self.numerator)) \
            .mul(math.sqrt(2 / math.pi))

    # Elsewhere, phi(z) / Phi(z) = exp(-z^2/2 - log Phi(z)) / sqrt(2 * pi);
    # the log(0.5) term folds the 0.5 factor of Phi into the exponent.
    exp = z[z_is_not_small].pow(2) \
        .div(-2) \
        .sub(log_phi_z[z_is_not_small]) \
        .add(math.log(0.5))
    log_phi_z_grad[z_is_not_small] = torch.exp(exp).mul(math.sqrt(2 / math.pi))

    return log_phi_z_grad.mul(grad_output)
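To see the gradient formula d/dz log Phi(z) = phi(z) / Phi(z) in isolation, here is a minimal self-contained sketch using the modern static-method autograd API. The class name LogPhi, the erf-based forward, and the gradcheck call are illustrative assumptions, not part of the original class; the continued-fraction branch for small z is omitted, so this version is only accurate where Phi(z) does not underflow.

import math
import torch

class LogPhi(torch.autograd.Function):
    """log of the standard normal CDF (sketch; no small-z branch)."""

    @staticmethod
    def forward(ctx, z):
        # Phi(z) = 0.5 * (1 + erf(z / sqrt(2)))
        log_phi_z = torch.log(0.5 * (1 + torch.erf(z / math.sqrt(2))))
        ctx.save_for_backward(z, log_phi_z)
        return log_phi_z

    @staticmethod
    def backward(ctx, grad_output):
        z, log_phi_z = ctx.saved_tensors
        # phi(z) / Phi(z) = exp(-z^2/2 - log Phi(z)) / sqrt(2 * pi)
        grad = torch.exp(z.pow(2).div(-2) - log_phi_z) / math.sqrt(2 * math.pi)
        return grad * grad_output

z = torch.randn(5, dtype=torch.double, requires_grad=True)
torch.autograd.gradcheck(LogPhi.apply, (z,))  # passes for moderate z

Checking the result numerically with gradcheck is the usual way to validate a hand-written backward; for strongly negative z this simple version loses accuracy, which is exactly why the original code switches to the continued-fraction branch there.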