def load_model(version, epoch, patch_size, batch_size=8, ctx=None,
               model_dir='models', num_channels=20):
    """Load an MXNet checkpoint and bind it as an inference-only Module.

    Parameters
    ----------
    version : str
        Checkpoint name, joined onto ``model_dir`` to form the prefix passed
        to ``mx.model.load_checkpoint``.
    epoch : int
        Checkpoint epoch to load.
    patch_size : int
        Size used for both trailing dimensions of the bound ``'data'`` input.
    batch_size : int, optional
        Leading dimension of the bound ``'data'`` input (default 8).
    ctx : mx.Context, optional
        Device to bind the module on. ``None`` (the default) means
        ``mx.gpu()``, resolved at call time rather than at import time.
    model_dir : str, optional
        Directory containing the checkpoint files (default ``'models'``).
    num_channels : int, optional
        Second dimension of the bound ``'data'`` input (default 20).

    Returns
    -------
    mx.module.Module
        Module bound with ``for_training=False`` and populated with the
        checkpoint's arg and aux parameters.
    """
    import os  # stdlib; imported locally so this helper is self-contained

    if ctx is None:
        # Choose the default device per call instead of evaluating
        # mx.gpu() once in the signature when the module is imported.
        ctx = mx.gpu()

    prefix = os.path.join(model_dir, version)
    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)

    mod = mx.module.Module(sym, context=ctx)
    # NOTE(review): assumes a channels-first (N, C, H, W) input with square
    # patches -- confirm against the network definition.
    data_shape = (batch_size, num_channels, patch_size, patch_size)
    mod.bind(data_shapes=[('data', data_shape)], for_training=False)
    mod.set_params(arg_params, aux_params)
    return mod
predict_utils.py 文件源码
python
阅读 15
收藏 0
点赞 0
评论 0
评论列表
文章目录