def setUp(self):
    """Build a small discrete HMM fixture shared by the tests.

    Sets ``self.model_steps`` (number of time steps), ``self.data`` (one
    ``pyro.ones(1, 1)`` observation per step) and ``self.model``, a
    zero-argument callable that runs the HMM under Pyro.
    """
    # Simple HMM with binary (Bernoulli) latent states and Bernoulli
    # emissions -- despite older wording, nothing in this model is Gaussian.
    def model():
        """Sample a Bernoulli latent chain and observe ``self.data`` at each step.

        Returns the sum of all latent values, including the initial
        ones-valued state that seeds the chain.
        """
        # Row k of each parameter is the Bernoulli probability used when the
        # conditioning latent value equals k (selected via index_select below).
        p_latent = pyro.param("p1", Variable(torch.Tensor([[0.7], [0.3]])))
        p_obs = pyro.param("p2", Variable(torch.Tensor([[0.9], [0.1]])))
        # The chain is seeded with a fixed 1x1 latent whose value is 1.
        latents = [Variable(torch.ones(1, 1))]
        for t in range(self.model_steps):
            # Transition: condition on the previous latent value.
            prev_idx = latents[-1].view(-1).long()
            latents.append(
                pyro.sample("latent_{}".format(t),
                            Bernoulli(torch.index_select(p_latent, 0, prev_idx))))
            # Emission: condition on the latent just sampled for step t.
            # The return value of pyro.observe was never used, so it is not kept.
            cur_idx = latents[-1].view(-1).long()
            pyro.observe("observe_{}".format(t),
                         Bernoulli(torch.index_select(p_obs, 0, cur_idx)),
                         self.data[t])
        return torch.sum(torch.cat(latents))

    # model() reads these attributes lazily at call time, so assigning them
    # after the closure is defined is fine.
    self.model_steps = 3
    self.data = [pyro.ones(1, 1) for _ in range(self.model_steps)]
    self.model = model
# Web-page scrape residue, not part of the test code:
# "评论列表" (Comment list), "文章目录" (Table of contents).