def plot_results(self, plot_path=None, ylim=None):
    """Plot the recorded training curves and optionally save them as PNGs.

    Two figures are built:

    * ``<plot_path>_epochs.png``: the ELBO ("training bound on the
      perplexity") at each evaluation epoch. The train curve is drawn with
      std error bars; the validation and test curves are plain lines.
    * ``<plot_path>_norm.png``: the mean norm of the updates, with std
      error bars, at each evaluation epoch.

    Attributes read from ``self``:
        epochs_eval: the evaluation epochs, used both as x values and as
            indices into the ``*_all`` sequences below.
        elbo_seq_train_all, elbo_seq_train_all_std, mean_norm_all,
        std_norm_all: indexed by each epoch in ``epochs_eval``.
        elbo_seq_valid_all, elbo_seq_test_all: plotted as-is against
            ``epochs_eval``. These presumably hold one entry per evaluation
            epoch (TODO confirm against the training loop).

    Args:
        plot_path: File-name prefix (``str`` or ``os.PathLike``) to which
            ``"_epochs.png"`` / ``"_norm.png"`` are appended. If ``None``,
            the figures are built and closed without being saved.
        ylim: Optional y-axis limits, applied to the ELBO figure only.
    """
    # Local imports: this method's file-level import header is not visible
    # here, and these are only needed by this method.
    import contextlib
    import os

    label_size = 18
    plot_params = {'ms': 10, 'linewidth': 3}
    # Scope the tick-label size to this call. Assigning into mpl.rcParams
    # directly would silently change every later plot in the process.
    tick_rc = {'xtick.labelsize': label_size, 'ytick.labelsize': label_size}

    @contextlib.contextmanager
    def _figure(suffix):
        """Yield fresh axes, save them as ``plot_path + suffix`` on success,
        and always close the figure so failures cannot leak open figures."""
        fig = plt.figure(figsize=[12, 12])
        try:
            ax = fig.add_subplot(111)
            yield ax
            if plot_path is not None:
                fig.savefig(os.fspath(plot_path) + suffix, format='png',
                            bbox_inches='tight', dpi=200)
        finally:
            plt.close(fig)

    # Saving happens inside the rc context as well: matplotlib may create
    # tick labels lazily at draw time, and they must still see the size.
    with mpl.rc_context(tick_rc):
        # ELBO (training bound on the perplexity) per evaluation epoch.
        with _figure("_epochs.png") as ax:
            ax.errorbar(self.epochs_eval,
                        [self.elbo_seq_train_all[i] for i in self.epochs_eval],
                        yerr=[self.elbo_seq_train_all_std[i] for i in self.epochs_eval],
                        marker='d', color='b', label='Train', **plot_params)
            ax.plot(self.epochs_eval, self.elbo_seq_valid_all, "-rh",
                    label="Valid", **plot_params)
            ax.plot(self.epochs_eval, self.elbo_seq_test_all, "-k^",
                    label="Test", **plot_params)
            ax.set_xlabel('Epochs', fontsize=20)
            # A real boolean: a string such as 'on' is merely truthy, and
            # 'off' would also have turned the grid on.
            ax.grid(True)
            ax.set_title('ELBO sequence', fontsize=24, y=1.01)
            ax.legend(loc="upper right", handlelength=3, fontsize=20)
            if ylim is not None:
                ax.set_ylim(ylim)

        # Norm of the updates per evaluation epoch (mean +/- std).
        with _figure("_norm.png") as ax:
            ax.errorbar(self.epochs_eval,
                        [self.mean_norm_all[i] for i in self.epochs_eval],
                        yerr=[self.std_norm_all[i] for i in self.epochs_eval],
                        marker='d', color='m', **plot_params)
            ax.grid(True)
            ax.set_title('Norm of the updates', fontsize=24, y=1.01)
评论列表
文章目录