def progress_func(epoch, loss, accuracy, valid_loss, valid_accuracy, test_loss, test_accuracy):
    """Report per-epoch metrics and do end-of-epoch bookkeeping.

    Prints the training metrics for ``epoch`` and, when available, the
    validation and test metrics. Then:

    * writes ``net`` to ``args.output`` whenever the validation accuracy
      beats the best value stored in ``progress_state``, and records that
      validation/test accuracy pair in ``progress_state``;
    * every ``args.save_iter`` epochs, writes a numbered snapshot named
      ``<base>_<epoch:04d><ext>`` derived from ``args.output``;
    * every ``args.lr_decay_iter`` epochs, multiplies ``optimizer.alpha``
      by ``args.lr_decay_ratio``.

    Relies on names from the enclosing scope: ``args``, ``net``,
    ``optimizer``, ``progress_state``, ``serializers`` and ``os``.

    Args:
        epoch: Number of the epoch that just finished. It is formatted with
            ``{:04d}`` for snapshot names, so it must be an integer.
        loss: Training mean loss.
        accuracy: Training accuracy.
        valid_loss: Validation mean loss; may be ``None``.
        valid_accuracy: Validation accuracy; may be ``None``. The validation
            line is printed only when both validation values are set.
        test_loss: Test mean loss; may be ``None``.
        test_accuracy: Test accuracy; may be ``None``. The test line is
            printed only when both test values are set.
    """
    # Call form of print: valid on both Python 2 and 3, and consistent with
    # the other print calls here (the original used the Python 2-only
    # print statement on this line).
    print('epoch: {} done'.format(epoch))
    print('train mean loss={}, accuracy={}'.format(loss, accuracy))
    if valid_loss is not None and valid_accuracy is not None:
        print('valid mean loss={}, accuracy={}'.format(valid_loss, valid_accuracy))
    if test_loss is not None and test_accuracy is not None:
        print('test mean loss={}, accuracy={}'.format(test_loss, test_accuracy))

    # Keep a best-so-far checkpoint at args.output. Higher accuracy is
    # better, so save when validation accuracy exceeds the stored best (the
    # original used `<`, which only saves when accuracy gets worse).
    # NOTE(review): assumes progress_state['valid_accuracy'] starts at a
    # floor such as 0.0 -- confirm at its initialization.
    # The None check skips this when no validation accuracy was produced;
    # ordering None against a number raises TypeError on Python 3.
    if (valid_accuracy is not None
            and valid_accuracy > progress_state['valid_accuracy']):
        serializers.save_npz(args.output, net)
        progress_state['valid_accuracy'] = valid_accuracy
        progress_state['test_accuracy'] = test_accuracy

    # Periodic numbered snapshot, e.g. "model.npz" -> "model_0010.npz" at
    # epoch 10. The interval is guarded like the decay schedule below, so a
    # non-positive interval disables snapshots instead of raising
    # ZeroDivisionError.
    if args.save_iter > 0 and epoch % args.save_iter == 0:
        base, ext = os.path.splitext(args.output)
        serializers.save_npz('{0}_{1:04d}{2}'.format(base, epoch, ext), net)

    # Step learning-rate decay. Presumably `alpha` is the optimizer's step
    # size (as on Chainer's Adam) -- confirm against the optimizer in use.
    if args.lr_decay_iter > 0 and epoch % args.lr_decay_iter == 0:
        optimizer.alpha *= args.lr_decay_ratio
评论列表
文章目录