plotting_utils.py source code

python
Project: ip-avsr  Author: lzuwei
import matplotlib.pyplot as plt


def plot_validation_cost(train_error, val_error, class_rate=None, savefilename=None):
    """Plot per-epoch training and validation cost, optionally with the
    classification rate on a secondary y-axis, then show (and optionally save) it."""
    epochs = range(len(train_error))
    fig, ax1 = plt.subplots()
    ax1.plot(epochs, train_error, label='train error')
    ax1.plot(epochs, val_error, label='validation error')
    ax1.set_xlabel('epoch')
    ax1.set_ylabel('cost')
    ax1.set_title('Validation Cost')
    # Copy into a plain list so lines from the second axis can be appended
    lines = list(ax1.get_lines())
    # Shrink the axis height by 10% from the bottom to make room for the legend
    box = ax1.get_position()
    ax1.set_position([box.x0, box.y0 + box.height * 0.1,
                      box.width, box.height * 0.9])
    if class_rate is not None:
        # Secondary y-axis that shares the epoch x-axis with ax1
        ax2 = ax1.twinx()
        ax2.plot(epochs, class_rate, label='classification rate', color='r')
        ax2.set_ylabel('classification rate')
        lines.extend(ax2.get_lines())
        # Give the twin axis the same shrunken position as ax1
        ax2.set_position([box.x0, box.y0 + box.height * 0.1,
                          box.width, box.height * 0.9])

    labels = [l.get_label() for l in lines]
    # Single legend covering both axes, placed below the plot area
    ax1.legend(lines, labels, loc='upper center', bbox_to_anchor=(0.5, -0.05),
               fancybox=False, shadow=False, ncol=5)
    # ax1.legend(lines, labels, loc='lower right')
    if savefilename:
        fig.savefig(savefilename)
    plt.show()
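The function above calls `plt.show()` unconditionally, which opens a blocking window under a GUI backend and is awkward in scripts or on a headless training server. The sketch below is a minimal variant under stated assumptions: it assumes matplotlib with the non-interactive Agg backend, and `plot_cost_curves` is a hypothetical name, not part of the ip-avsr project. It keeps the same technique (a `twinx()` secondary axis whose lines are merged into one legend below the plot) but returns the figure so the caller decides whether to save or show it, and it reserves room for the legend with `fig.subplots_adjust(bottom=...)` instead of shrinking each axis by hand.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: render to files, never open a window
import matplotlib.pyplot as plt


def plot_cost_curves(train_error, val_error, class_rate=None):
    """Hypothetical headless variant of plot_validation_cost.

    Returns the Figure instead of calling plt.show().
    """
    epochs = range(len(train_error))
    fig, ax1 = plt.subplots()
    ax1.plot(epochs, train_error, label="train error")
    ax1.plot(epochs, val_error, label="validation error")
    ax1.set_xlabel("epoch")
    ax1.set_ylabel("cost")
    ax1.set_title("Validation Cost")

    # Collect line handles from the primary axis
    handles = list(ax1.get_lines())
    if class_rate is not None:
        # Secondary y-axis sharing the same epoch x-axis
        ax2 = ax1.twinx()
        ax2.plot(epochs, class_rate, label="classification rate", color="r")
        ax2.set_ylabel("classification rate")
        handles.extend(ax2.get_lines())

    # Leave space at the bottom of the figure for the legend
    fig.subplots_adjust(bottom=0.25)
    # One legend for lines from both axes, anchored below the plot area
    ax1.legend(handles, [h.get_label() for h in handles],
               loc="upper center", bbox_to_anchor=(0.5, -0.12), ncol=3)
    return fig
```

Typical use: `fig = plot_cost_curves(train, val, class_rate=rates)`, then `fig.savefig("cost.png")` and `plt.close(fig)` to free memory when plotting many runs in a loop.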