import mxnet as mx


def train():
    # Run on the requested GPU, or fall back to CPU.
    ctx = mx.gpu(args.gpu) if args.gpu >= 0 else mx.cpu()
    # Training iterator: mean_r/scale normalize pixels to roughly [-1, 1].
    train = mx.io.MNISTIter(
        image='data/train-images-idx3-ubyte',
        label='data/train-labels-idx1-ubyte',
        input_shape=(1, 28, 28),
        mean_r=128,
        scale=1. / 128,
        batch_size=args.batch_size,
        shuffle=True)
    # Validation iterator: same preprocessing, but no shuffling.
    val = mx.io.MNISTIter(
        image='data/t10k-images-idx3-ubyte',
        label='data/t10k-labels-idx1-ubyte',
        input_shape=(1, 28, 28),
        mean_r=128,
        scale=1. / 128,
        batch_size=args.batch_size)
    symbol = get_symbol()
    mod = mx.mod.Module(
        symbol=symbol,
        context=ctx,
        data_names=('data',),
        label_names=('softmax_label',))
    # MNIST has 60,000 training images; one epoch is this many batches.
    num_examples = 60000
    epoch_size = num_examples // args.batch_size
    optim_params = {
        'learning_rate': args.lr,
        'momentum': 0.9,
        'wd': 0.0005,
        # Multiply the learning rate by 0.1 every 10 epochs.
        'lr_scheduler': mx.lr_scheduler.FactorScheduler(step=10 * epoch_size, factor=0.1),
    }
    mod.fit(train_data=train,
            eval_data=val,
            eval_metric=mx.metric.Accuracy(),
            initializer=mx.init.Xavier(),
            optimizer='sgd',
            optimizer_params=optim_params,
            num_epoch=args.num_epoch,
            # Log throughput every 50 batches; save a checkpoint each epoch.
            batch_end_callback=mx.callback.Speedometer(args.batch_size, 50),
            epoch_end_callback=mx.callback.do_checkpoint(args.model_prefix))
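The function above assumes that get_symbol() and a parsed args object are defined elsewhere in the script. For reference, here is a minimal sketch of both, assuming a simple multilayer perceptron for the network; the flag names and defaults (--gpu, --batch-size, --lr, --num-epoch, --model-prefix) are illustrative, not necessarily those of the original script.

import argparse

import mxnet as mx


def get_symbol():
    # A small MLP. The output layer is named 'softmax', so the module's
    # label_names=('softmax_label',) above matches it automatically.
    data = mx.sym.Variable('data')
    flat = mx.sym.Flatten(data=data)
    fc1 = mx.sym.FullyConnected(data=flat, num_hidden=128, name='fc1')
    act1 = mx.sym.Activation(data=fc1, act_type='relu', name='relu1')
    fc2 = mx.sym.FullyConnected(data=act1, num_hidden=10, name='fc2')
    return mx.sym.SoftmaxOutput(data=fc2, name='softmax')


# Hypothetical command-line interface; argparse maps '--batch-size'
# to args.batch_size, '--num-epoch' to args.num_epoch, and so on.
parser = argparse.ArgumentParser(description='Train an MLP on MNIST')
parser.add_argument('--gpu', type=int, default=-1)  # -1 means CPU
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--lr', type=float, default=0.1)
parser.add_argument('--num-epoch', type=int, default=20)
parser.add_argument('--model-prefix', type=str, default='mnist')
args = parser.parse_args()

if __name__ == '__main__':
    train()

With these pieces in place the script can be run as, e.g., python train_mnist.py --gpu 0 --batch-size 128. Each epoch, do_checkpoint writes mnist-symbol.json plus a mnist-%04d.params file, which can later be restored with mx.mod.Module.load(args.model_prefix, epoch).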