inspect_weight_dist.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:additions_mxnet 作者: eldercrow 项目源码 文件源码
def inspect_weight_dist(prefix_net, epoch):
    #
    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix_net, epoch)

    quantize_bit = 5

    err_log = {}
    err_uni = {}

    err_diff = []

    for k in sorted(arg_params):
        if not k.endswith('_weight'):
            continue
        v = arg_params[k].asnumpy().ravel()

        err_log[k] = measure_log_quantize_error(v, quantize_bit)
        err_uni[k] = measure_uni_quantize_error(v, quantize_bit)

        err_diff.append(err_log[k] - err_uni[k])

    plt.plot(range(len(err_diff)), err_diff)

    import ipdb
    ipdb.set_trace()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号