def sample_gaussian_2d_train(mux, muy, sx, sy, corr, nodesPresent):
    """Draw one 2-D Gaussian sample for every node listed in ``nodesPresent``.

    For each node index ``i`` in ``range(mux.size()[0])`` that appears in
    ``nodesPresent``, one point is drawn with numpy's global RNG from a
    bivariate normal with mean ``(mux[i], muy[i])`` and covariance
    ``[[sx[i]**2, corr[i]*sx[i]*sy[i]], [corr[i]*sx[i]*sy[i], sy[i]**2]]``.
    Nodes are visited in increasing index order, one draw each.

    Args:
        mux, muy: per-node means; tensors indexed by node along dim 0.
        sx, sy: per-node standard deviations, same indexing.
        corr: per-node correlation coefficients, same indexing.
        nodesPresent: container of node indices to sample (tested with
            ``in``); presumably a list of ints -- TODO confirm with callers.

    Returns:
        ``(next_x, next_y)``: 1-D tensors of length ``mux.size()[0]`` created
        with ``torch.zeros``. Entries for nodes not in ``nodesPresent`` stay 0.
        The samples are produced by numpy, so no gradient flows back to the
        inputs.
    """
    numNodes = mux.size()[0]
    next_x = torch.zeros(numNodes)
    next_y = torch.zeros(numNodes)

    # Detach from autograd and copy to host once, up front. numpy cannot
    # consume tensors that require grad or live on a GPU (passing the raw
    # tensors raises), and one transfer per input avoids a device round trip
    # for every node.
    mux_h, muy_h, sx_h, sy_h, corr_h = (
        t.detach().cpu() for t in (mux, muy, sx, sy, corr)
    )

    for node in range(numNodes):
        if node not in nodesPresent:
            continue
        s_x = float(sx_h[node])
        s_y = float(sy_h[node])
        cov_xy = float(corr_h[node]) * s_x * s_y
        mean = [float(mux_h[node]), float(muy_h[node])]
        cov = [[s_x * s_x, cov_xy], [cov_xy, s_y * s_y]]
        next_values = np.random.multivariate_normal(mean, cov, 1)
        next_x[node] = next_values[0][0]
        next_y[node] = next_values[0][1]
    return next_x, next_y
评论列表
文章目录