def sample_next_state(self):
    """Sample the next state vector from the mixture outputs stored on ``self``.

    Reads element ``[0, 0, ...]`` of ``self.pi``, ``self.q``, ``self.mu_x``,
    ``self.mu_y``, ``self.sigma_x``, ``self.sigma_y`` and ``self.rho_xy``.
    A mixture index is drawn from the temperature-adjusted ``pi`` and a pen
    state index from the temperature-adjusted ``q``; the selected mixture
    parameters are then passed to ``sample_bivariate_normal`` to draw
    ``(x, y)``.

    Returns:
        A 5-tuple ``(next_state, x, y, q_idx == 1, q_idx == 2)`` where
        ``next_state`` is a ``Variable`` viewed as ``(1, 1, 5)`` holding
        ``[x, y, one-hot pen state]`` (moved to the GPU when ``use_cuda``
        is true), and the last two items are booleans telling whether the
        sampled pen state index was 1 or 2.
        NOTE(review): presumably these flag "pen lifted" and "end of
        drawing" -- confirm against the caller.

    Raises:
        ValueError: propagated from ``np.random.choice`` when the adjusted
            probabilities are invalid (e.g. contain NaN) or their length
            does not match ``hp.M`` / 3.
    """

    def adjust_temp(pi_pdf):
        """Rescale a probability vector by ``hp.temperature`` and renormalize."""
        # Work in float64: np.random.choice checks that ``p`` sums to 1 and
        # float32 round-off in the normalization can trip that check.
        pi_pdf = np.asarray(pi_pdf, dtype=np.float64)
        # A zero probability maps to -inf here and back to weight 0 after
        # exp(), which is the intended result; silence only that warning.
        with np.errstate(divide="ignore"):
            log_pdf = np.log(pi_pdf)
        log_pdf /= hp.temperature
        # Subtract the max before exponentiating to avoid overflow.
        log_pdf -= log_pdf.max()
        weights = np.exp(log_pdf)
        weights /= weights.sum()
        return weights

    # get mixture index:
    pi = self.pi.data[0, 0, :].cpu().numpy()
    pi = adjust_temp(pi)
    pi_idx = np.random.choice(hp.M, p=pi)

    # get pen state:
    q = self.q.data[0, 0, :].cpu().numpy()
    q = adjust_temp(q)
    q_idx = np.random.choice(3, p=q)

    # get mixture params of the selected component:
    mu_x = self.mu_x.data[0, 0, pi_idx]
    mu_y = self.mu_y.data[0, 0, pi_idx]
    sigma_x = self.sigma_x.data[0, 0, pi_idx]
    sigma_y = self.sigma_y.data[0, 0, pi_idx]
    rho_xy = self.rho_xy.data[0, 0, pi_idx]
    x, y = sample_bivariate_normal(
        mu_x, mu_y, sigma_x, sigma_y, rho_xy, greedy=False
    )

    # Layout: [x, y, pen-state one-hot over indices 2..4].
    next_state = torch.zeros(5)
    next_state[0] = x
    next_state[1] = y
    next_state[q_idx + 2] = 1

    if use_cuda:
        next_state = next_state.cuda()
    return Variable(next_state).view(1, 1, -1), x, y, q_idx == 1, q_idx == 2
评论列表
文章目录