mnist.py source code

python

Project: mx-lsoftmax  Author: luoyetx
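import mxnet as mx

# `args` and `get_symbol()` are defined elsewhere in mnist.py and are not shown here.
def train():
    # Use the requested GPU, or fall back to the CPU when args.gpu is negative.
    ctx = mx.gpu(args.gpu) if args.gpu >= 0 else mx.cpu()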
    train = mx.io.MNISTIter(
                image='data/train-images-idx3-ubyte',
                label='data/train-labels-idx1-ubyte',
                input_shape=(1, 28, 28),
                mean_r=128,
                scale=1./128,
                batch_size=args.batch_size,
                shuffle=True)
    val = mx.io.MNISTIter(
                image='data/t10k-images-idx3-ubyte',
                label='data/t10k-labels-idx1-ubyte',
                input_shape=(1, 28, 28),
                mean_r=128,
                scale=1./128,
                batch_size=args.batch_size)
    symbol = get_symbol()
    mod = mx.mod.Module(
            symbol=symbol,
            context=ctx,
            data_names=('data',),
            label_names=('softmax_label',))
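    num_examples = 60000  # number of images in the MNIST training set
    epoch_size = num_examples // args.batch_size  # optimizer updates per epoch
    optim_params = {
        'learning_rate': args.lr,
        'momentum': 0.9,
        'wd': 0.0005,
        # Scale the learning rate by 0.1 after every 10 epochs' worth of updates.
        'lr_scheduler': mx.lr_scheduler.FactorScheduler(step=10*epoch_size, factor=0.1),
    }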
    mod.fit(train_data=train,
            eval_data=val,
            eval_metric=mx.metric.Accuracy(),
            initializer=mx.init.Xavier(),
            optimizer='sgd',
            optimizer_params=optim_params,
            num_epoch=args.num_epoch,
            batch_end_callback=mx.callback.Speedometer(args.batch_size, 50),
            epoch_end_callback=mx.callback.do_checkpoint(args.model_prefix))
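
The learning-rate schedule configured above can be illustrated with a small sketch in plain Python, without MXNet. It is only an approximation of what `FactorScheduler(step=10*epoch_size, factor=0.1)` does: the helper `stepped_lr` is hypothetical, the batch size of 100 and base learning rate of 0.1 are assumed values (the listing takes them from `args.batch_size` and `args.lr`), and the exact update at which MXNet applies each decay may differ by one step.

```python
def stepped_lr(base_lr, num_update, step, factor):
    """Approximate learning rate after `num_update` optimizer updates.

    Hypothetical helper: the rate is multiplied by `factor` once for
    every `step` updates completed so far.
    """
    return base_lr * factor ** (num_update // step)


num_examples = 60000                      # MNIST training set size, as in the listing
batch_size = 100                          # assumed; the listing uses args.batch_size
base_lr = 0.1                             # assumed; the listing uses args.lr
epoch_size = num_examples // batch_size   # updates per epoch
step = 10 * epoch_size                    # decay interval, as in the listing

# Print the approximate learning rate at the start of a few epochs.
for epoch in (0, 10, 20):
    print(epoch, stepped_lr(base_lr, epoch * epoch_size, step, 0.1))
```

With these assumed values one epoch is 600 updates, so the rate is cut by a factor of ten roughly every 6000 updates, that is, every 10 epochs.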