log_normal_cdf.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:gpytorch 作者: jrg365 项目源码 文件源码
def backward(self, grad_output):
        z, log_phi_z = self.saved_tensors
        log_phi_z_grad = z.new().resize_as_(z).zero_()

        z_is_small = z.lt(-1)
        z_is_not_small = 1 - z_is_small

        if z_is_small.sum() > 0:
            log_phi_z_grad[z_is_small] = torch.abs(self.denominator.div(self.numerator)).mul(math.sqrt(2 / math.pi))

        exp = z[z_is_not_small].pow(2) \
                               .div(-2) \
                               .sub(log_phi_z[z_is_not_small]) \
                               .add(math.log(0.5))

        log_phi_z_grad[z_is_not_small] = torch.exp(exp).mul(math.sqrt(2 / math.pi))

        return log_phi_z_grad.mul(grad_output)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号