def load_param(prefix, epoch, convert=False, ctx=None, process=False):
    """
    Load a checkpoint and optionally post-process the loaded parameters.

    :param prefix: forwarded unchanged to ``load_checkpoint``.
    :param epoch: forwarded unchanged to ``load_checkpoint``.
    :param convert: when true, both parameter collections are passed
        through ``convert_context`` before being returned.
    :param ctx: context handed to ``convert_context``; when None,
        ``mx.cpu()`` is used. Only consulted when ``convert`` is true.
    :param process: when true, every key of ``arg_params`` containing
        ``'_test'`` is re-stored under the same name with every
        ``'_test'`` occurrence removed (replacing any entry already
        stored under the resulting name).
    :return: (arg_params, aux_params)
    """
    args, auxs = load_checkpoint(prefix, epoch)

    if convert:
        # Fall back to the CPU context when the caller did not pick one.
        target_ctx = mx.cpu() if ctx is None else ctx
        args = convert_context(args, target_ctx)
        auxs = convert_context(auxs, target_ctx)

    if process:
        # Snapshot the matching keys first: the mapping is modified inside
        # the loop, so it must not be iterated directly while renaming.
        test_keys = [name for name in args.keys() if '_test' in name]
        for name in test_keys:
            value = args.pop(name)
            args[name.replace('_test', '')] = value

    return args, auxs
# (web page residue) Comment list
# (web page residue) Table of contents