plotting.py source code

python

Project: kvae  Author: simonkamronn
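import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns


def plot_auxiliary(all_vars, filename, table_size=4):
    """Plot a list of arrays, one colour per array, and save the figure to `filename`.

    2-D states are drawn as trajectories on a table_size x table_size grid (one batch
    element per panel); states of any other dimension are drawn as a seaborn pair plot.
    """
    # All variables need to be (batch_size, sequence_length, dimension);
    # promote 2-D inputs to a batch of one without mutating the caller's list.
    all_vars = [np.expand_dims(a, 0) if a.ndim == 2 else a for a in all_vars]

    dim = all_vars[0].shape[-1]
    if dim == 2:
        f, ax = plt.subplots(table_size, table_size, sharex='col', sharey='row', figsize=[12, 12])
        # Panels are filled in row-major order with consecutive batch elements,
        # so every variable needs at least table_size ** 2 batch elements.
        idx = 0
        for x in range(table_size):
            for y in range(table_size):
                # Overlay every variable's trajectory for batch element idx
                for a in all_vars:
                    ax[x, y].plot(a[idx, :, 0], a[idx, :, 1], linestyle='-', marker='o', markersize=3)
                    # Mark the starting point of the trajectory
                    ax[x, y].plot(a[idx, 0, 0], a[idx, 0, 1], 'r.', ms=12)
                idx += 1
        plt.savefig(filename, format='png', bbox_inches='tight', dpi=80)
    else:
        # Stack every variable into one long table; the 'class' column records
        # which input each row came from and is used as the hue.
        df_list = []
        for i, a in enumerate(all_vars):
            df = pd.DataFrame(a.reshape(-1, dim))
            df['class'] = i
            df_list.append(df)

        df_all = pd.concat(df_list, ignore_index=True)
        sns_plot = sns.pairplot(df_all, hue="class", vars=list(range(dim)))
        sns_plot.savefig(filename)
    # Close the current figure in either branch to free memory
    plt.close()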
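The data layout `plot_auxiliary` expects can be checked without drawing anything. The sketch below uses NumPy only, and the helper names `to_batched`, `stack_long` and `panel_index` are hypothetical, not part of kvae. It mirrors two steps of the function: promoting 2-D inputs to a batch of one, and stacking every variable into one long table whose label column plays the role of the `class` hue column passed to `sns.pairplot`. It also shows the row-major mapping from grid cell to batch element used in the 2-D branch, which is why that branch needs at least `table_size ** 2` batch elements per variable.

```python
import numpy as np


def to_batched(arrays):
    """Return arrays shaped (batch_size, sequence_length, dimension).

    Hypothetical helper: a 2-D array (sequence_length, dimension) becomes a batch
    of one, mirroring the np.expand_dims(a, 0) step in plot_auxiliary.
    """
    return [np.expand_dims(a, 0) if a.ndim == 2 else a for a in arrays]


def stack_long(arrays):
    """Flatten each (batch, seq, dim) array to rows and tag rows with the array's index.

    Hypothetical helper: returns (data, labels) where data has shape (total_rows, dim)
    and labels[k] is the index of the input array row k came from (the 'class' column).
    """
    batched = to_batched(arrays)
    dim = batched[0].shape[-1]
    rows = [a.reshape(-1, dim) for a in batched]
    labels = np.concatenate([np.full(len(r), i) for i, r in enumerate(rows)])
    return np.concatenate(rows), labels


def panel_index(x, y, table_size):
    """Hypothetical helper: batch element drawn in grid cell (x, y), row-major order."""
    return x * table_size + y


# One unbatched trajectory (seq=5, dim=3) and a batch of 2 trajectories (seq=4, dim=3)
a = np.zeros((5, 3))
b = np.ones((2, 4, 3))
data, labels = stack_long([a, b])
print(data.shape)            # (13, 3)
print(np.bincount(labels))   # [5 8]
print(panel_index(1, 2, 4))  # 6
```

With `table_size=4`, the last panel `(3, 3)` reads batch element 15, so a batch of 16 is the minimum for the 2-D branch.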