def backward(self, grad_output):
    """Return the gradients w.r.t. ``z``, ``mu`` and ``sig``.

    With ``c = gamma_under + gamma_over`` and ``pdf`` the normal density
    with mean ``mu`` and standard deviation ``sig`` evaluated at ``z``,
    the expressions below are the partial derivatives of ``c * pdf``:

    * d/dz   = -c * (z - mu) / sig**2 * pdf
    * d/dmu  = -(d/dz)
    * d/dsig =  c * ((z - mu)**2 - sig**2) / sig**3 * pdf

    NOTE(review): the matching ``forward`` is not visible here; presumably
    it returns ``c * pdf(z)`` -- confirm against the enclosing class.

    Args:
        grad_output: Upstream gradient; multiplied element-wise into each
            of the three partial derivatives.

    Returns:
        A 3-tuple ``(grad_z, grad_mu, grad_sig)``.
    """
    z, mu, sig = self.saved_tensors
    scale = self.gamma_under + self.gamma_over

    # SciPy only works on host NumPy arrays, so the density is evaluated on
    # the CPU. ``detach()`` keeps ``.numpy()`` from raising on tensors that
    # track gradients.
    dist = st.norm(mu.detach().cpu().numpy(), sig.detach().cpu().numpy())
    density = dist.pdf(z.detach().cpu().numpy())

    # Move the result to the device the inputs live on instead of forcing
    # CUDA: a hard-coded ``.cuda()`` fails on CPU-only machines and would
    # produce gradients on a different device than CPU inputs. The dtype
    # stays float64, as before.
    pz = torch.as_tensor(density, dtype=torch.double, device=z.device)

    dz = -scale * (z - mu) / (sig**2) * pz
    dmu = -dz
    dsig = scale * ((z - mu)**2 - sig**2) / (sig**3) * pz
    return grad_output * dz, grad_output * dmu, grad_output * dsig
评论列表
文章目录