def plot_trajectories(src_sent, src_encoding, idx):
    """Plot the per-timestep trajectories of one encoded sentence and save them.

    Args:
        src_sent: Sequence of string tokens; joined with single spaces to form
            the plot title.
        src_encoding: Array-like with a ``.shape`` attribute, expected to be
            ``(time_steps, hidden_dim)``. Axis 0 becomes the x-axis
            (timestep). For a 2-D input, matplotlib draws each column as its
            own line.
        idx: Value converted with ``str`` and appended to the output file name
            ``misc_hidden_cell_trajectories_<idx>``. No extension is given,
            so matplotlib falls back to its configured default save format.

    The figure created here is always closed, even if titling or saving
    raises, and pyplot's current figure is never drawn on.
    """
    # Use a dedicated figure instead of pyplot's implicit "current" figure so
    # a caller's open figure is neither drawn on nor closed by this helper.
    fig, ax = plt.subplots()
    try:
        times = np.arange(src_encoding.shape[0])
        ax.plot(times, src_encoding)
        ax.set_title(" ".join(src_sent))
        ax.set_xlabel('timestep')
        ax.set_ylabel('trajectories')
        fig.savefig("misc_hidden_cell_trajectories_" + str(idx),
                    bbox_inches="tight")
    finally:
        # Release the figure on every path; without this, an exception above
        # would leave figures open and accumulate memory over repeated calls.
        plt.close(fig)
评论列表
文章目录