def run_pinwheel():
    """Fit ``vfe.SGPLVM`` to a synthetic pinwheel dataset and plot the result.

    Samples a three-armed, 150-point pinwheel dataset, optimises a
    ``vfe.SGPLVM`` model on it with L-BFGS-B, then shows the observed data
    (left panel) next to the model's posterior latent points with
    ``sqrt(vx)`` error bars (right panel).
    """

    def make_pinwheel(radial_std, tangential_std, num_classes, num_per_class,
                      rate, rs=None):
        """Sample 2-D points arranged as a twisted pinwheel.

        Based on code by Ryan P. Adams.

        Args:
            radial_std: scale of the Gaussian noise on the radial coordinate.
            tangential_std: scale of the Gaussian noise on the tangential
                coordinate.
            num_classes: number of pinwheel arms.
            num_per_class: number of points sampled for each arm.
            rate: how strongly each arm twists; the rotation angle grows with
                ``rate * exp(radial coordinate)``.
            rs: ``np.random.RandomState`` to sample from. When omitted, a new
                ``RandomState(0)`` is created, so default calls are
                reproducible.

        Returns:
            Array of shape ``(num_classes * num_per_class, 2)``. Rows are
            grouped by arm, with ``num_per_class`` consecutive rows per arm.
        """
        if rs is None:
            # Built per call instead of as a default argument so that repeated
            # calls do not share (and silently advance) a single generator.
            rs = np.random.RandomState(0)
        # One evenly spaced base angle per arm.
        rads = np.linspace(0, 2 * np.pi, num_classes, endpoint=False)
        features = rs.randn(num_classes * num_per_class, 2) \
            * np.array([radial_std, tangential_std])
        # Centre the radial coordinate at 1 instead of 0.
        features[:, 0] += 1
        labels = np.repeat(np.arange(num_classes), num_per_class)
        # The extra rotation grows with radial distance, producing the twist.
        angles = rads[labels] + rate * np.exp(features[:, 0])
        rotations = np.stack([np.cos(angles), -np.sin(angles),
                              np.sin(angles), np.cos(angles)])
        # (4, N) -> (N, 2, 2): one 2x2 rotation matrix per point.
        rotations = np.reshape(rotations.T, (-1, 2, 2))
        # Multiply each point (as a row vector) by its own rotation matrix.
        return np.einsum('ti,tij->tj', features, rotations)

    # create dataset
    print("creating dataset...")
    Y = make_pinwheel(radial_std=0.3, tangential_std=0.05, num_classes=3,
                      num_per_class=50, rate=0.4)

    # inference
    print("inference ...")
    # NOTE(review): presumably M is the number of inducing points and D the
    # latent dimensionality expected by vfe.SGPLVM -- confirm against vfe.
    M = 20
    D = 2
    lvm = vfe.SGPLVM(Y, D, M, lik='Gaussian')
    lvm.optimise(method='L-BFGS-B')
    # Assumed: posterior means and variances of the latent inputs (the
    # variances are square-rooted below to draw error bars) -- verify in vfe.
    mx, vx = lvm.get_posterior_x()

    # Left: observed data. Right: first two latent coordinates with error bars.
    fig = plt.figure()
    ax = fig.add_subplot(121)
    ax.plot(Y[:, 0], Y[:, 1], 'bx')
    ax = fig.add_subplot(122)
    ax.errorbar(mx[:, 0], mx[:, 1], xerr=np.sqrt(vx[:, 0]),
                yerr=np.sqrt(vx[:, 1]), fmt='xk')
    plt.show()
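# Scraped web-page residue, not code: 评论列表 ("comment list")
# Scraped web-page residue, not code: 文章目录 ("table of contents")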