def plot_latent(model, y, plot_title='', save_dir='/tmp'):
    """Plot the model's predictive transition curve and save it as a PNG.

    Evaluates ``model.predict_y`` on a regular grid of 300 inputs and draws:
    the predictive mean (red line) with a +/- 2 standard deviation band, and
    the consecutive observation pairs ``(y[t-1], y[t])`` as red crosses. The
    figure is written to ``<save_dir>/lincos_<plot_title>.png``.

    Args:
        model: fitted model exposing ``get_hypers()`` (with a
            ``'C_emission'`` array entry), ``predict_y(x)`` returning a
            ``(mean, variance)`` pair, and an integer ``N`` used to slice ``y``.
        y: observed sequence; entries ``0 .. model.N - 1`` are plotted.
        plot_title: figure title, also embedded in the output file name.
        save_dir: directory the PNG is written to (default ``'/tmp'``,
            matching the previously hard-coded location).

    Returns:
        The matplotlib figure, so callers can show or close it.
    """
    import os

    n_test = 300
    # Scalar emission weight. The grid is divided by C here and multiplied
    # back when plotting, so the horizontal axis spans [-10, 8].
    C = model.get_hypers()['C_emission'][0, 0]
    x_test = np.reshape(np.linspace(-10, 8, n_test) / C, [n_test, 1])
    my, vy = model.predict_y(x_test)

    x_plot = C * x_test[:, 0]
    mean = my[:, 0]
    band = 2 * np.sqrt(vy[:, 0])

    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.plot(x_plot, mean, '-', color='r', label='y')
    ax.fill_between(
        x_plot, mean + band, mean - band,
        alpha=0.2, edgecolor='r', facecolor='r')
    # Consecutive observation pairs: y[t-1] on x, y[t] on y.
    ax.plot(y[0:model.N - 1], y[1:model.N], 'r+', alpha=0.5)
    ax.set_xlabel(r'$x_{t-1}$')
    ax.set_ylabel(r'$x_{t}$')
    # Title and save through the figure's own objects rather than pyplot's
    # "current figure" state.
    ax.set_title(plot_title)
    fig.savefig(os.path.join(save_dir, 'lincos_' + plot_title + '.png'))
    return fig
# Generate a dataset from the lincos function above.