lin_cos_exp.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:geepee 作者: thangbui 项目源码 文件源码
def plot_latent(model, y, plot_title='', out_dir='/tmp'):
    """Plot the model's predictive transition curve against observed pairs.

    Builds a grid of latent inputs, evaluates ``model.predict_y`` on it and
    draws the predictive mean with a +/- 2 standard-deviation band. Observed
    consecutive pairs ``(y[t-1], y[t])`` are overlaid as red crosses. The grid
    is ``linspace(-10, 8, 300) / C``, and the curve is drawn at ``C * grid``.
    Here ``C`` is ``model.get_hypers()['C_emission'][0, 0]``. So the
    horizontal axis of the curve spans [-10, 8], the same scale as ``y``.

    The figure is saved as ``<out_dir>/lincos_<plot_title>.png``.

    Args:
        model: fitted model exposing ``get_hypers()``, ``predict_y(x)``
            (returning mean and variance arrays indexed as ``[:, 0]``) and an
            attribute ``N`` used as a slice bound into ``y``.
        y: observation sequence; ``y[0:N-1]`` is plotted against ``y[1:N]``.
        plot_title: figure title, also embedded in the output file name.
        out_dir: directory the PNG is written to (default ``/tmp``, matching
            the previously hard-coded location).

    Returns:
        The matplotlib ``Figure``. It is left open, as before; call
        ``plt.close(fig)`` when generating many plots to release memory.

    Raises:
        ValueError: if ``C`` is zero, because the latent grid is obtained by
            dividing by ``C``.
    """
    import os

    n_test = 300
    C = model.get_hypers()['C_emission'][0, 0]
    if C == 0:
        raise ValueError(
            "C_emission[0, 0] is zero; cannot map the plotting grid to latent space")

    # Choose the grid on the observation scale, then map it back to latent
    # space so predict_y receives latent inputs.
    x_test = np.reshape(np.linspace(-10, 8, n_test) / C, [n_test, 1])
    my, vy = model.predict_y(x_test)

    x_obs = C * x_test[:, 0]
    band = 2 * np.sqrt(vy[:, 0])

    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.plot(x_obs, my[:, 0], '-', color='r', label='y')
    ax.fill_between(
        x_obs,
        my[:, 0] + band,
        my[:, 0] - band,
        alpha=0.2, edgecolor='r', facecolor='r')
    # Observed one-step transitions.
    ax.plot(y[0:model.N - 1], y[1:model.N], 'r+', alpha=0.5)
    ax.set_xlabel(r'$x_{t-1}$')
    ax.set_ylabel(r'$x_{t}$')
    ax.set_title(plot_title)
    fig.savefig(os.path.join(out_dir, 'lincos_' + plot_title + '.png'))
    return fig

# generate a dataset from the lincos function above
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号