mnist.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:mx-lsoftmax 作者: luoyetx 项目源码 文件源码
def profile():
    ctx = mx.gpu(args.gpu) if args.gpu >=0 else mx.cpu()
    val = mx.io.MNISTIter(
            image='data/t10k-images-idx3-ubyte',
            label='data/t10k-labels-idx1-ubyte',
            input_shape=(1, 28, 28),
            mean_r=128,
            scale=1./128,
            batch_size=args.batch_size)
    symbol = get_symbol()
    mod = mx.mod.Module(
            symbol=symbol,
            context=ctx,
            data_names=('data',),
            label_names=('softmax_label',))
    mod.bind(data_shapes=val.provide_data, label_shapes=val.provide_label, for_training=True)
    mod.init_params(initializer=mx.init.Xavier())

    # run a while
    for nbatch, data_batch in enumerate(val):
        mod.forward_backward(data_batch)

    # profile
    mx.profiler.profiler_set_config(mode='symbolic', filename='profile.json')
    mx.profiler.profiler_set_state('run')
    val.reset()
    for nbatch, data_batch in enumerate(val):
        mod.forward_backward(data_batch)
    mx.profiler.profiler_set_state('stop')
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号