def load(cls, checkpoint_path, ctx=None, **kwargs):
    """Load a model from ``checkpoint_path``.

    The constructor arguments are read from ``"<checkpoint_path>.json"``.
    The model is built as ``cls(**args)`` and its parameters are then
    loaded from ``checkpoint_path`` itself.

    Args:
        checkpoint_path: Path prefix of the checkpoint. The JSON file with
            the constructor arguments is expected at
            ``"{checkpoint_path}.json"``.
        ctx: Context to load the parameters onto. Defaults to ``mx.cpu()``,
            resolved at call time rather than at function-definition time.
        **kwargs: Extra constructor arguments. These override any
            same-named keys stored in the JSON file.

    Returns:
        The constructed model with its parameters loaded.

    Raises:
        FileNotFoundError: If the ``.json`` argument file does not exist.
        json.JSONDecodeError: If the argument file is not valid JSON.
        TypeError: If the argument file does not contain a JSON object.
    """
    if ctx is None:
        # Resolved lazily so the default is not created once at import time.
        ctx = mx.cpu()

    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open("{}.json".format(checkpoint_path), encoding="utf-8") as f:
        model_args = json.load(f)

    if not isinstance(model_args, dict):
        raise TypeError(
            "expected a JSON object of constructor arguments in {}.json, "
            "got {}".format(checkpoint_path, type(model_args).__name__)
        )

    # Caller-supplied kwargs take precedence over the saved arguments.
    # Passing both via ``**model_args, **kwargs`` would raise on any overlap.
    model = cls(**{**model_args, **kwargs})
    # NOTE(review): Gluon deprecated ``load_params`` in favour of
    # ``load_parameters`` (MXNet >= 1.3). It is kept here for compatibility;
    # confirm the target MXNet version before switching.
    model.load_params(checkpoint_path, ctx)
    logger.info("model loaded: %s.", checkpoint_path)
    return model
评论列表
文章目录