bayesian_regression.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:pyro 作者: uber 项目源码 文件源码
def guide(data):
    """Variational guide: sample a regressor with random weight and bias.

    Four learnable tensors are created and registered in the Pyro param
    store: a mean and an unconstrained spread parameter for the weight
    (shape ``(p, 1)``) and for the bias (shape ``(1,)``). Each spread
    parameter is passed through ``softplus`` before being handed to
    ``Normal`` as its second argument (presumably the standard deviation --
    confirm against the ``Normal`` in scope).

    Args:
        data: object whose ``.data`` attribute is a tensor; it is used only
            to choose the tensor type of the initial parameter values.

    Returns:
        The result of calling the module produced by ``pyro.random_module``,
        i.e. ``regression_model`` with ``linear.weight`` and ``linear.bias``
        replaced by the guide distributions.
    """
    # Reference tensor: only its type is used, never its values.
    like = data.data

    # Initial values. The random draws happen in the same order as the
    # parameters are listed: weight mean, weight spread, bias mean, bias spread.
    weight_loc_init = torch.randn(p, 1).type_as(like)
    weight_spread_init = (-3.0 * torch.ones(p, 1) + 0.05 * torch.randn(p, 1)).type_as(like)
    bias_loc_init = torch.randn(1).type_as(like)
    bias_spread_init = (-3.0 * torch.ones(1) + 0.05 * torch.randn(1)).type_as(like)

    # Register the learnable parameters with the param store; softplus maps
    # the unconstrained spread parameters to positive values.
    weight_loc = pyro.param(
        "guide_mean_weight",
        Variable(weight_loc_init, requires_grad=True),
    )
    weight_spread = softplus(pyro.param(
        "guide_log_sigma_weight",
        Variable(weight_spread_init, requires_grad=True),
    ))
    bias_loc = pyro.param(
        "guide_mean_bias",
        Variable(bias_loc_init, requires_grad=True),
    )
    bias_spread = softplus(pyro.param(
        "guide_log_sigma_bias",
        Variable(bias_spread_init, requires_grad=True),
    ))

    # Gaussian guide distributions, keyed by the module parameter they replace.
    guide_dists = {
        'linear.weight': Normal(weight_loc, weight_spread),
        'linear.bias': Normal(bias_loc, bias_spread),
    }

    # Lift the regression module so its parameters are drawn from the guide
    # distributions, then sample one concrete regressor from it.
    sample_regressor = pyro.random_module("module", regression_model, guide_dists)
    return sample_regressor()


# instantiate optim and inference objects
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号