def plot_auxiliary(all_vars, filename, table_size=4):
    """Plot one or more batches of trajectories and save the figure to ``filename``.

    Every array in ``all_vars`` is expected to be
    ``(batch_size, sequence_length, dimension)``; a 2-D array is treated as a
    batch of one by prepending a batch axis. The plotting mode is chosen from
    the last dimension of the first array:

    * ``dimension == 2``: a ``table_size`` x ``table_size`` grid of subplots is
      drawn, one batch element per subplot in row-major order. Every array in
      ``all_vars`` is overlaid in that subplot as a line with markers, and its
      first point is marked with a large red dot. The figure is saved as PNG
      (``bbox_inches='tight'``, 80 dpi). If the smallest batch has fewer than
      ``table_size ** 2`` elements, the remaining subplots are left empty.
    * otherwise: each array is flattened to ``(-1, dimension)`` rows, tagged
      with its position in ``all_vars`` as column ``'class'``, and drawn with
      ``seaborn.pairplot`` coloured by that class.

    Args:
        all_vars: sequence of numpy arrays (2-D or 3-D) sharing the last
            dimension. The caller's sequence is not modified.
        filename: output path passed to ``savefig``.
        table_size: number of rows and columns of the grid used in the
            2-D mode.
    """
    # Normalise to 3-D in a fresh list rather than writing back into the
    # caller's sequence (which would mutate it, or fail for a tuple).
    batched = [np.expand_dims(a, 0) if a.ndim == 2 else a for a in all_vars]
    dim = batched[0].shape[-1]
    if dim == 2:
        _plot_trajectory_grid(batched, filename, table_size)
    else:
        _plot_class_pairplot(batched, filename, dim)


def _plot_trajectory_grid(batched, filename, table_size):
    """Draw one batch element per subplot of a square grid and save it as PNG."""
    # squeeze=False keeps ``axes`` 2-D even when table_size == 1.
    fig, axes = plt.subplots(table_size, table_size, sharex='col', sharey='row',
                             squeeze=False, figsize=[12, 12])
    try:
        # Never index past the end of the smallest batch.
        n_items = min(a.shape[0] for a in batched)
        for idx in range(min(n_items, table_size * table_size)):
            ax = axes[divmod(idx, table_size)]
            for a in batched:
                ax.plot(a[idx, :, 0], a[idx, :, 1], linestyle='-', marker='o', markersize=3)
                # Mark the starting point of the trajectory.
                ax.plot(a[idx, 0, 0], a[idx, 0, 1], 'r.', ms=12)
        fig.savefig(filename, format='png', bbox_inches='tight', dpi=80)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)


def _plot_class_pairplot(batched, filename, dim):
    """Save a seaborn pairplot of all arrays, coloured by their index in ``batched``."""
    frames = []
    for cls, a in enumerate(batched):
        frame = pd.DataFrame(a.reshape(-1, dim))
        frame['class'] = cls
        frames.append(frame)
    # ignore_index gives the combined frame unique row labels instead of
    # repeating 0..n-1 once per input array.
    df_all = pd.concat(frames, ignore_index=True)
    grid = sns.pairplot(df_all, hue="class", vars=list(range(dim)))
    grid.savefig(filename)
    plt.close()
评论列表
文章目录