def plot(m, Xtrain, ytrain):
    """Plot a fitted model's predictions together with its training data.

    In a new matplotlib figure, draw:
      * the training points as black crosses;
      * the predictive mean from ``m.predict_y`` on a 100-point grid over
        [-0.5, 1.5], with a shaded +/- 2 standard-deviation band;
      * ``m.predict_f`` evaluated at the model's inducing inputs ``zu``, as
        red points with +/- 2 standard-deviation error bars.
    The visible x-range is then limited to [-0.1, 1.1].

    Args:
        m: Fitted model exposing ``predict_y(x)`` and ``predict_f(x)``, each
            returning a ``(mean, var)`` pair. The inducing inputs are read from
            ``m.sgp_layers[0].zu`` for ``aep.SDGPR``, from ``m.zu`` for
            ``vfe.SGPR_collapsed``, and from ``m.sgp_layer.zu`` otherwise.
        Xtrain: Training inputs. Presumably 1-D inputs, matching the
            (100, 1) prediction grid built below; TODO confirm.
        ytrain: Training targets, plotted against ``Xtrain``.
    """
    xx = np.linspace(-0.5, 1.5, 100)[:, None]
    mean, var = m.predict_y(xx)
    # Force (num_points, 1) so column 0 can be sliced below. This reshape only
    # succeeds for a single-output prediction.
    mean = np.reshape(mean, (xx.shape[0], 1))
    var = np.reshape(var, (xx.shape[0], 1))

    # Each supported model type keeps its inducing inputs in a different place.
    if isinstance(m, aep.SDGPR):
        zu = m.sgp_layers[0].zu
    elif isinstance(m, vfe.SGPR_collapsed):
        zu = m.zu
    else:
        zu = m.sgp_layer.zu
    mean_u, var_u = m.predict_f(zu)

    # errorbar() expects 1-D x / y / yerr. A column-shaped (M, 1) yerr is
    # rejected by recent matplotlib's shape check, so flatten everything to (M,).
    zu_flat = np.ravel(zu)
    mean_u_flat = np.ravel(mean_u)
    std_u = np.sqrt(np.ravel(var_u))

    std = np.sqrt(var[:, 0])

    plt.figure()
    plt.plot(Xtrain, ytrain, 'kx', mew=2)
    plt.plot(xx, mean, 'b', lw=2)
    plt.fill_between(
        xx[:, 0],
        mean[:, 0] - 2 * std,
        mean[:, 0] + 2 * std,
        color='blue', alpha=0.2)
    plt.errorbar(zu_flat, mean_u_flat, yerr=2 * std_u, fmt='ro')
    plt.xlim(-0.1, 1.1)
# Scraped web-page residue, not code: "评论列表" ("comment list"), "文章目录" ("table of contents")