def profile():
    """Profile forward/backward passes of the network on the MNIST test set.

    Builds a Module from ``get_symbol()`` and initializes it with Xavier
    weights. It first runs one full pass over the MNIST test iterator
    without profiling, then a second full pass with the MXNet profiler
    enabled in symbolic mode. The trace is written to ``profile.json``.

    Reads ``args.gpu`` (a negative value selects the CPU) and
    ``args.batch_size`` from a module-level ``args``. That is presumably an
    argparse namespace defined elsewhere in this file — verify.
    ``get_symbol`` is likewise defined outside this function.
    """
    ctx = mx.gpu(args.gpu) if args.gpu >= 0 else mx.cpu()
    val = mx.io.MNISTIter(
        image='data/t10k-images-idx3-ubyte',
        label='data/t10k-labels-idx1-ubyte',
        input_shape=(1, 28, 28),
        mean_r=128,
        scale=1./128,
        batch_size=args.batch_size)
    symbol = get_symbol()
    mod = mx.mod.Module(
        symbol=symbol,
        context=ctx,
        data_names=('data',),
        label_names=('softmax_label',))
    # for_training=True is needed because forward_backward runs backward().
    mod.bind(data_shapes=val.provide_data, label_shapes=val.provide_label,
             for_training=True)
    mod.init_params(initializer=mx.init.Xavier())

    # Un-profiled warm-up pass ("run a while"). Presumably this keeps
    # one-time startup costs out of the profile.
    for data_batch in val:
        mod.forward_backward(data_batch)
    # MXNet executes asynchronously. Block until the warm-up work has
    # finished so it is not recorded once the profiler starts running.
    mx.nd.waitall()

    # Profile one full pass over the data.
    mx.profiler.profiler_set_config(mode='symbolic', filename='profile.json')
    mx.profiler.profiler_set_state('run')
    val.reset()
    for data_batch in val:
        mod.forward_backward(data_batch)
    # Let every queued op complete before stopping. Otherwise the tail of
    # the profiled pass is missing from the trace.
    mx.nd.waitall()
    mx.profiler.profiler_set_state('stop')
评论列表
文章目录