utils.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:ngraph 作者: NervanaSystems 项目源码 文件源码
def save_plot(niters, loss, args):
    print('Saving training loss-iteration figure...')
    try:
        import matplotlib
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt

        name = 'Train-{}_hs-{}_lr-{}_bs-{}'.format(args.train_file, args.hs,
                                                   args.lr, args.batch_size)
        plt.title(name)
        plt.plot(niters, loss)
        plt.xlabel('iteration')
        plt.ylabel('loss')
        plt.savefig(name + '.jpg')
        print('{} saved!'.format(name + '.jpg'))

    except ImportError:
        print('matplotlib not installed and no figure is saved.')
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号