def plot_trajectory_uncertainty(true, gen, filter, smooth, filename, threshold=0.5):
    """Plot the per-step Hamming distance of three binarized predictions vs. ground truth.

    Every array is binarized with ``> threshold``. For each index along axis 1
    (treated as the step axis), that slice is flattened across all remaining
    axes. It is then compared with the matching slice of ``true`` using
    ``hamming``. One line is drawn per prediction ("Generated", "Filtered",
    "Smoothed") with pyplot, the figure is saved to ``filename``, and the
    current figure is closed.

    Args:
        true: Ground-truth array; axis 1 is iterated as the step axis. Earlier
            versions required exactly 4 dimensions -- presumably
            (sequences, timesteps, h, w); TODO confirm with callers.
        gen: Prediction array indexed like ``true``, plotted as "Generated".
        filter: Prediction array indexed like ``true``, plotted as "Filtered".
            The name shadows the builtin but is kept so that keyword callers
            keep working.
        smooth: Prediction array indexed like ``true``, plotted as "Smoothed".
        filename: Destination passed to ``plt.savefig``.
        threshold: Values strictly greater than this count as "on". The
            default of 0.5 matches the previously hard-coded value.

    Returns:
        None. Side effects: draws on the current pyplot figure, writes
        ``filename``, and closes the current figure.
    """
    timesteps = true.shape[1]
    # Binarize the ground truth once, instead of once per (label, step) pair.
    true_on = true > threshold
    # 1-based step numbers 1..timesteps for the x axis.
    steps = np.arange(1, timesteps + 1)

    errors = dict(Generated=list(), Filtered=list(), Smoothed=list())
    for label, var in zip(('Generated', 'Filtered', 'Smoothed'), (gen, filter, smooth)):
        var_on = var > threshold
        for step in range(timesteps):
            # NOTE(review): `hamming` is defined/imported elsewhere; presumably
            # scipy.spatial.distance.hamming (fraction of mismatches) -- confirm.
            errors[label].append(hamming(true_on[:, step].ravel(), var_on[:, step].ravel()))
        # NOTE(review): ms (markersize) only has an effect if a marker is set,
        # e.g. via rcParams; it is kept to preserve existing output.
        plt.plot(steps, errors[label], linewidth=3, ms=20, label=label)

    plt.xlabel('Steps', fontsize=20)
    plt.ylabel('Hamming distance', fontsize=20)
    plt.legend(fontsize=20)
    plt.savefig(filename)
    plt.close()
# Stray web-page residue from the scraped source page: "Comment list",
# "Table of contents".