train_srnn_midi.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:srnn 作者: marcofraccaro 项目源码 文件源码
def plot_results(self, plot_path=None, ylim=None):
        """Draw the training curves and optionally save them as PNG files.

        Two figures are produced:

        * ``<plot_path>_epochs.png`` -- ELBO against ``self.epochs_eval`` for the
          train set (with error bars from ``self.lower_bound_train_all_std``),
          the validation set and the test set.
        * ``<plot_path>_norm.png`` -- mean norm of the updates against
          ``self.epochs_eval``, with error bars from ``self.std_norm_all``.

        Side effect: ``xtick.labelsize`` and ``ytick.labelsize`` are written to
        the global ``mpl.rcParams``, so the change persists after this call.

        :param plot_path: path prefix for the output files. When ``None`` the
            figures are drawn and closed without being saved.
        :param ylim: optional y-axis limits passed to ``plt.ylim``; applied to
            the ELBO figure only.
        """
        # Plotting parameters (tick label size is set globally via rcParams).
        label_size = 18
        mpl.rcParams['xtick.labelsize'] = label_size
        mpl.rcParams['ytick.labelsize'] = label_size
        plot_params = dict()
        plot_params['ms'] = 10
        plot_params['linewidth'] = 3
        epochs = self.epochs_eval

        # Plot the ELBO for the train, valid and test sets.
        f = plt.figure(figsize=[12, 12])
        try:
            # The train mean/std are looked up by epoch number, whereas the
            # valid/test sequences are plotted as-is. Presumably they hold one
            # entry per element of epochs_eval -- TODO confirm with the caller.
            plt.errorbar(epochs, [self.lower_bound_train_all[i] for i in epochs],
                         yerr=[self.lower_bound_train_all_std[i] for i in epochs],
                         marker='d', color='b', label='Train', **plot_params)
            plt.plot(epochs, self.lower_bound_valid_all, "-rh", label="Valid", **plot_params)
            plt.plot(epochs, self.lower_bound_test_all, "-k^", label="Test", **plot_params)
            plt.xlabel('Epochs', fontsize=20)
            # grid() documents a bool; the former 'on' string only worked by being truthy.
            plt.grid(True)
            plt.title('ELBO', fontsize=24, y=1.01)
            plt.legend(loc="lower right", handlelength=3, fontsize=20)
            if ylim is not None:
                plt.ylim(ylim)
            if plot_path is not None:
                plt.savefig(plot_path + "_epochs.png", format='png', bbox_inches='tight', dpi=200)
        finally:
            # Always release the figure, even if plotting or saving failed,
            # so it does not leak in pyplot's global figure manager.
            plt.close(f)

        # Plot the norm of the updates (mean +/- std per evaluated epoch).
        f = plt.figure(figsize=[12, 12])
        try:
            plt.errorbar(epochs, [self.mean_norm_all[i] for i in epochs],
                         yerr=[self.std_norm_all[i] for i in epochs],
                         marker='d', color='m', **plot_params)
            plt.grid(True)
            plt.title('Norm of the updates', fontsize=24, y=1.01)
            if plot_path is not None:
                plt.savefig(plot_path + "_norm.png", format='png', bbox_inches='tight', dpi=200)
        finally:
            plt.close(f)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号