plotting.py 文件源码

python
阅读 66 收藏 0 点赞 0 评论 0

项目:kvae 作者: simonkamronn 项目源码 文件源码
def plot_trajectory_uncertainty(true, gen, filter, smooth, filename):
    sequences, timesteps, h, w = true.shape

    errors = dict(Generated=list(), Filtered=list(), Smoothed=list())
    for label, var in zip(('Generated', 'Filtered', 'Smoothed'), (gen, filter, smooth)):
        for step in range(timesteps):
            errors[label].append(hamming(true[:, step].ravel() > 0.5, var[:, step].ravel() > 0.5))

        plt.plot(np.linspace(1, timesteps, num=timesteps).astype(int), errors[label], linewidth=3, ms=20, label=label)
    plt.xlabel('Steps', fontsize=20)
    plt.ylabel('Hamming distance', fontsize=20)
    plt.legend(fontsize=20)
    plt.savefig(filename)
    plt.close()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号