def setUp(self):
    """Build a conjugate Normal-Normal model/guide fixture.

    Model: ``mu ~ Normal(0, 1)`` and each row of ``self.data`` is observed
    under ``Normal(mu, 1)``. Guide: samples ``mu`` from ``Normal(0, 1)`` with
    no learnable parameters.

    Because prior and likelihood both have unit scale and the data are all
    zeros, the exact posterior of ``mu`` is Normal with mean 0 and variance
    ``1 / (n_data + 1)``. That posterior is stored in ``self.mu_mean`` /
    ``self.mu_stddev``, presumably as the reference the test methods compare
    inference results against.

    Side effects: clears the global Pyro parameter store and sets
    ``self.data``, ``self.mu_mean``, ``self.mu_stddev``, ``self.model`` and
    ``self.guide``.
    """
    pyro.clear_param_store()
    # Single source of truth for the dataset size. The observation batch
    # size, the data tensor and the analytic posterior all depend on it, so
    # deriving them from one name keeps them consistent.
    n_data = 50

    def model():
        # Prior over the latent mean: mu ~ Normal(0, 1).
        mu = pyro.sample("mu", Normal(Variable(torch.zeros(1)),
                                      Variable(torch.ones(1))))
        # Likelihood: n_data observations with unit scale around mu.
        xd = Normal(mu, Variable(torch.ones(1)), batch_size=n_data)
        # self.data is read when model() runs, not when it is defined.
        pyro.observe("xs", xd, self.data)
        return mu

    def guide():
        # Same distribution as the prior; the guide has no parameters.
        return pyro.sample("mu", Normal(Variable(torch.zeros(1)),
                                        Variable(torch.ones(1))))

    # data: n_data scalar observations, all zero
    self.data = Variable(torch.zeros(n_data, 1))
    # Exact conjugate posterior of mu:
    #   precision = 1 (prior) + n_data * 1 (likelihood) = n_data + 1
    #   mean      = sum(data) / precision = 0 (data are all zero)
    self.mu_mean = Variable(torch.zeros(1))
    self.mu_stddev = torch.sqrt(Variable(torch.ones(1)) / float(n_data + 1))
    # model and guide
    self.model = model
    self.guide = guide
# (scraped web-page residue, not code: "Comment list", "Table of contents")