def load_toy_data(n_samples=1000, dtype='float32', seed=None):
    """Create the "two Melbournes" toy dataset, posed as an inverse problem.

    Every sample gets (almost) the same 2-d input feature, but the targets are
    drawn from two bivariate Gaussians: one centred on Melbourne, Florida
    (``n_samples`` points) and one centred on Melbourne, Australia
    (``2 * n_samples`` points). So the input alone cannot determine the target.

    Parameters
    ----------
    n_samples : int
        Number of Florida samples. The Australian cluster gets twice as many,
        so ``3 * n_samples`` samples are generated in total.
    dtype : str or numpy dtype
        Dtype applied to the feature matrix and to the target array.
    seed : int or None
        Optional seed for a private ``np.random.RandomState``. With ``None``
        (the default) all draws come from the global ``np.random`` state,
        exactly as before this parameter existed.

    Returns
    -------
    tuple
        ``(X_train, Y_train, X_dev, Y_dev, X_test, Y_test, U_train, U_dev,
        U_test, None, None, userLocation, None)``.

        - ``X_*``: ``scipy.sparse.csr_matrix`` slices with 2 columns.
        - ``Y_*``: ``(k, 2)`` arrays of (lat, lon).
        - ``U_*``: lists of integer sample ids.
        - ``userLocation``: maps each id to a ``"lat,lon"`` string.

        The split is ``2 * n_samples`` train, ``n_samples // 2`` dev, and the
        remainder test. The ``None`` entries are presumably slots that this
        toy dataset does not provide; confirm against the caller.
    """
    rng = np.random if seed is None else np.random.RandomState(seed)

    print('creating Melbourne toy dataset as an inverse problem.')
    print('There are two (if not more) Melbournes, one in Australia and one in Florida, USA')
    mlb_fl_latlon_mean = np.array((28.0836, -80.6081))
    mlb_au_latlon_mean = np.array((-37.8136, 144.9631))
    # Unit variances, zero correlation.
    cov = np.array([[1, 0], [0, 1]])
    mlb_fl_samples = rng.multivariate_normal(mean=mlb_fl_latlon_mean, cov=cov, size=n_samples).astype(dtype)
    # Australia gets twice as many samples as Florida (1:2 imbalance).
    mlb_au_samples = rng.multivariate_normal(mean=mlb_au_latlon_mean, cov=cov, size=n_samples * 2).astype(dtype)

    n_total = 3 * n_samples
    # Uninformative inputs: every row is (1, 0) plus small uniform noise.
    X = sp.sparse.csr_matrix(rng.uniform(-0.1, 0.1, size=(n_total, 2)) + np.array([1, 0])).astype(dtype)
    Y = np.vstack((mlb_fl_samples, mlb_au_samples))

    # Shuffle X and Y with the same permutation so rows stay paired.
    indices = np.arange(n_total)
    rng.shuffle(indices)
    X = X[indices]
    Y = Y[indices]

    # Integer division: slice bounds and range() arguments must be ints.
    # Under Python 3, ``n_samples / 2`` is a float and raises TypeError.
    n_train_samples = 2 * n_samples
    n_dev_samples = n_samples // 2
    dev_end = n_train_samples + n_dev_samples
    test_end = n_total  # the test split takes whatever remains

    X_train = X[0:n_train_samples, :]
    X_dev = X[n_train_samples:dev_end, :]
    X_test = X[dev_end:test_end, :]
    Y_train = Y[0:n_train_samples, :]
    Y_dev = Y[n_train_samples:dev_end, :]
    Y_test = Y[dev_end:test_end, :]
    U_train = list(range(n_train_samples))
    U_dev = list(range(n_train_samples, dev_end))
    U_test = list(range(dev_end, test_end))

    # Keep explicit str(): format()/f-strings would render a float32 through
    # Python float and produce a longer, different string.
    userLocation = {i: str(lat) + ',' + str(lon) for i, (lat, lon) in enumerate(Y)}

    data = (X_train, Y_train, X_dev, Y_dev, X_test, Y_test, U_train, U_dev, U_test, None, None, userLocation, None)
    return data
# (Scraped web-page navigation text, not code: "Comment list" / "Table of contents".)