def guide(data):
    # define variational parameters (initial values) for the weight and bias
    w_mu = Variable(torch.randn(p, 1).type_as(data.data), requires_grad=True)
    w_log_sig = Variable((-3.0 * torch.ones(p, 1) + 0.05 * torch.randn(p, 1)).type_as(data.data), requires_grad=True)
    b_mu = Variable(torch.randn(1).type_as(data.data), requires_grad=True)
    b_log_sig = Variable((-3.0 * torch.ones(1) + 0.05 * torch.randn(1)).type_as(data.data), requires_grad=True)
    # register learnable params in the param store
    mw_param = pyro.param("guide_mean_weight", w_mu)
    sw_param = softplus(pyro.param("guide_log_sigma_weight", w_log_sig))
    mb_param = pyro.param("guide_mean_bias", b_mu)
    sb_param = softplus(pyro.param("guide_log_sigma_bias", b_log_sig))
    # gaussian guide distributions for w and b
    w_dist = Normal(mw_param, sw_param)
    b_dist = Normal(mb_param, sb_param)
    dists = {'linear.weight': w_dist, 'linear.bias': b_dist}
    # overload the parameters in the module with random samples from the guide distributions
    lifted_module = pyro.random_module("module", regression_model, dists)
    # sample a regressor from the approximate posterior
    return lifted_module()
# instantiate optim and inference objects
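# A minimal sketch of that step, assuming a probabilistic model function named
# `model` defined earlier in the post and an illustrative learning rate (neither
# is taken from this section); depending on the Pyro version, the loss argument
# is either the string "ELBO" or a Trace_ELBO() instance:
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

optim = Adam({"lr": 0.05})
svi = SVI(model, guide, optim, loss=Trace_ELBO())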