import gzip
import cPickle
import numpy as np
import matplotlib.pyplot as plt
from scipy import special
# geepee's approximate-EP models; this import path is an assumption and may
# need adjusting to match your installation
import geepee.aep_models as aep


def run_mnist():
    np.random.seed(42)

    # load the pickled MNIST splits
    f = gzip.open('./tmp/data/mnist.pkl.gz', 'rb')
    (x_train, t_train), (x_valid, t_valid), (x_test, t_test) = cPickle.load(f)
    f.close()

    # use the first 100 training images and binarise the pixels to {-1, +1}
    Y = x_train[:100, :]
    labels = t_train[:100]
    Y[Y < 0.5] = -1
    Y[Y > 0.5] = 1

    # inference: a sparse GPLVM with a 2D latent space, 30 inducing points
    # and a probit likelihood for the binary pixels
    print "inference ..."
    M = 30
    D = 2
    # lvm = aep.SGPLVM(Y, D, M, lik='Gaussian')
    lvm = aep.SGPLVM(Y, D, M, lik='Probit')
    # lvm.train(alpha=0.5, no_epochs=10, n_per_mb=100, lrate=0.1, fixed_params=['sn'])
    lvm.optimise(method='L-BFGS-B', alpha=0.1)

    # plot the posterior means of the latent points, coloured by digit label,
    # together with the inducing inputs (black dots)
    plt.figure()
    mx, vx = lvm.get_posterior_x()
    zu = lvm.sgp_layer.zu
    plt.scatter(mx[:, 0], mx[:, 1], c=labels)
    plt.plot(zu[:, 0], zu[:, 1], 'ko')

    # sweep a 30x30 grid over the latent space and decode each grid point
    # into a 28x28 image of predicted pixel probabilities
    nx = ny = 30
    x_values = np.linspace(-5, 5, nx)
    y_values = np.linspace(-5, 5, ny)
    sx = 28
    sy = 28
    canvas = np.empty((sx * ny, sy * nx))
    for i, yi in enumerate(x_values):
        for j, xi in enumerate(y_values):
            z_mu = np.array([[xi, yi]])
            x_mean, x_var = lvm.predict_f(z_mu)
            # probit predictive probability: Phi(mean / sqrt(1 + var))
            t = x_mean / np.sqrt(1 + x_var)
            Z = 0.5 * (1 + special.erf(t / np.sqrt(2)))
            canvas[(nx - i - 1) * sx:(nx - i) * sx,
                   j * sy:(j + 1) * sy] = Z.reshape(sx, sy)

    plt.figure(figsize=(8, 10))
    Xi, Yi = np.meshgrid(x_values, y_values)
    plt.imshow(canvas, origin="upper", cmap="gray")
    plt.tight_layout()
    plt.show()
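
A minimal entry point for running the example as a standalone script; it only assumes the imports sketched above and the MNIST pickle path used inside the function.

if __name__ == '__main__':
    # run the full pipeline: load data, fit the SGPLVM, and show the plots
    run_mnist()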