train.py source file


Project: chainer-cf-nade  Author: dsanno
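import os
from chainer import serializers

def progress_func(epoch, loss, accuracy, valid_loss, valid_accuracy, test_loss, test_accuracy):
    # Per-epoch callback; uses args, net, optimizer and progress_state from the enclosing scope of train.py.
    print('epoch: {} done'.format(epoch))
    print('train mean loss={}, accuracy={}'.format(loss, accuracy))
    if valid_loss is not None and valid_accuracy is not None:
        print('valid mean loss={}, accuracy={}'.format(valid_loss, valid_accuracy))
    if test_loss is not None and test_accuracy is not None:
        print('test mean loss={}, accuracy={}'.format(test_loss, test_accuracy))
    # Save to args.output whenever the validation metric drops below the best (lowest) value so far.
    # The None check avoids a TypeError in Python 3 when no validation set is used.
    if valid_accuracy is not None and valid_accuracy < progress_state['valid_accuracy']:
        serializers.save_npz(args.output, net)
        progress_state['valid_accuracy'] = valid_accuracy
        progress_state['test_accuracy'] = test_accuracy
    # Every args.save_iter epochs, also write a numbered snapshot, e.g. model.npz -> model_0010.npz.
    if epoch % args.save_iter == 0:
        base, ext = os.path.splitext(args.output)
        serializers.save_npz('{0}_{1:04d}{2}'.format(base, epoch, ext), net)
    # Step decay: every args.lr_decay_iter epochs, scale the optimizer's step size alpha
    # (the Adam hyperparameter in Chainer) by args.lr_decay_ratio.
    if args.lr_decay_iter > 0 and epoch % args.lr_decay_iter == 0:
        optimizer.alpha *= args.lr_decay_ratio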
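The three pieces of bookkeeping in progress_func (best-checkpoint tracking, numbered snapshots, and step decay of alpha) can be checked without Chainer. The following is a minimal, hypothetical pure-Python restatement, not code from the chainer-cf-nade repository: the helper names snapshot_path, decayed_alpha and best_epoch are made up for illustration. It assumes epochs are numbered from 1, that the callback runs once per epoch, and that progress_state['valid_accuracy'] starts at infinity; the real initial value is not shown in this file.

```python
import os


def snapshot_path(output, epoch):
    """Numbered snapshot name, mirroring '{0}_{1:04d}{2}' in progress_func (hypothetical helper)."""
    base, ext = os.path.splitext(output)
    return '{0}_{1:04d}{2}'.format(base, epoch, ext)


def decayed_alpha(alpha0, epoch, lr_decay_iter, lr_decay_ratio):
    """Step size after the callback has run for epochs 1..epoch (hypothetical helper).

    One multiplication by lr_decay_ratio happens at every epoch divisible by
    lr_decay_iter, i.e. epoch // lr_decay_iter times in total.
    """
    if lr_decay_iter <= 0:
        return alpha0  # decay disabled, as in the original condition
    return alpha0 * lr_decay_ratio ** (epoch // lr_decay_iter)


def best_epoch(valid_metrics):
    """1-based epoch of the last save to args.output, treating lower as better.

    Assumes the stored best starts at +inf (an assumption, see lead-in).
    Ties keep the earlier epoch because the comparison is strict.
    """
    best_value, best = float('inf'), None
    for epoch, value in enumerate(valid_metrics, start=1):
        if value is not None and value < best_value:
            best_value, best = value, epoch
    return best


if __name__ == '__main__':
    print(snapshot_path('model.npz', 10))        # model_0010.npz
    print(decayed_alpha(0.001, 25, 10, 0.5))     # 0.00025
    print(best_epoch([0.95, 0.91, 0.93, 0.90]))  # 4
```

With lr_decay_iter=10 and lr_decay_ratio=0.5, alpha is halved after epochs 10 and 20, so at epoch 25 it is a quarter of its starting value. The closed form alpha0 * ratio ** (epoch // lr_decay_iter) matches the repeated in-place `optimizer.alpha *=` update up to floating-point rounding.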