load_model.py source code

python

Project: mx-rfcn  Author: giorking

import mxnet as mx
import numpy as np

# config, load_checkpoint and convert_context are defined elsewhere in the mx-rfcn project.


def load_param_rcnn(prefix, epoch, convert=False, ctx=None):
    """
    Wrapper around load_checkpoint for R-CNN models.
    :param prefix: Prefix of the model name.
    :param epoch: Epoch number of the model to load.
    :param convert: if True, copy the loaded NDArrays to the context given by ctx.
    :param ctx: target context for the conversion; defaults to mx.cpu() when None.
    :return: (arg_params, aux_params, num_classes)
    """
    arg_params, aux_params = load_checkpoint(prefix, epoch)
    num_classes = 1000
    if "bbox_pred_bias" in arg_params:
        # bbox_pred has 4 regression outputs per class; use integer division
        # so num_classes can be passed to np.tile below.
        num_classes = len(arg_params['bbox_pred_bias'].asnumpy()) // 4
    if config.TRAIN.BBOX_NORMALIZATION_PRECOMPUTED and "bbox_pred_bias" in arg_params:
        print("load model with mean/std")
        # Per-coordinate constants, tiled across all classes: shape (1, 4 * num_classes)
        means = np.tile(np.array(config.TRAIN.BBOX_MEANS_INV), (1, num_classes))
        stds = np.tile(np.array(config.TRAIN.BBOX_STDS_INV), (1, num_classes))
        # Scale each output row of bbox_pred by its std, then shift and scale the bias
        arg_params['bbox_pred_weight'] = (arg_params['bbox_pred_weight'].T * mx.nd.array(stds)).T
        arg_params['bbox_pred_bias'] = (arg_params['bbox_pred_bias'] - mx.nd.array(np.squeeze(means))) * \
                                       mx.nd.array(np.squeeze(stds))

    if convert:
        if ctx is None:
            ctx = mx.cpu()
        arg_params = convert_context(arg_params, ctx)
        aux_params = convert_context(aux_params, ctx)
    return arg_params, aux_params, num_classes
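The bbox_pred rescaling step works because the layer is linear: scaling each output row of the weight by s and replacing the bias with (b - m) * s gives the same result as applying y -> (y - m) * s to the layer's output. Below is a minimal NumPy-only sketch of that identity, assuming a weight of shape (4 * num_classes, in_dim) as in the function above. The helper fold_bbox_transform and the random weights and constants are hypothetical, for illustration only; they are not part of mx-rfcn.

```python
import numpy as np


def fold_bbox_transform(weight, bias, means, stds, num_classes):
    """Hypothetical helper: fold y -> (y - m) * s into a linear layer y = W x + b.

    weight: (4 * num_classes, in_dim), bias: (4 * num_classes,)
    means, stds: 4 per-coordinate constants, tiled across all classes.
    """
    m = np.tile(np.asarray(means, dtype=float), num_classes)  # (4 * num_classes,)
    s = np.tile(np.asarray(stds, dtype=float), num_classes)
    new_weight = (weight.T * s).T  # row j of the weight is multiplied by s[j]
    new_bias = (bias - m) * s
    return new_weight, new_bias


rng = np.random.default_rng(0)
num_classes, in_dim = 3, 5
W = rng.normal(size=(4 * num_classes, in_dim))
b = rng.normal(size=4 * num_classes)
x = rng.normal(size=in_dim)
means = [0.1, -0.1, 0.0, 0.05]  # made-up constants
stds = [0.1, 0.1, 0.2, 0.2]

W2, b2 = fold_bbox_transform(W, b, means, stds, num_classes)

# Output of the folded layer vs. transforming the original layer's output
folded = W2 @ x + b2
direct = (W @ x + b - np.tile(means, num_classes)) * np.tile(stds, num_classes)
print(np.allclose(folded, direct))
```

Because the rescaling is folded into the parameters once at load time, no extra per-box arithmetic is needed when the network runs.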