def inspect_weight_dist(prefix_net, epoch, quantize_bit=5):
    """Compare log-domain vs. uniform quantization error for each weight tensor.

    Loads the MXNet checkpoint identified by ``prefix_net`` / ``epoch``,
    flattens every parameter whose name ends in ``_weight``, and measures its
    quantization error with both ``measure_log_quantize_error`` and
    ``measure_uni_quantize_error`` at ``quantize_bit`` bits. The per-tensor
    difference (log error minus uniform error) is drawn with ``plt.plot``,
    one point per weight tensor in sorted-name order. The figure is not shown;
    call ``plt.show()`` (or save it) afterwards.

    Args:
        prefix_net: Checkpoint path prefix passed to ``mx.model.load_checkpoint``.
        epoch: Checkpoint epoch passed to ``mx.model.load_checkpoint``.
        quantize_bit: Bit width handed to both error measures (default 5).

    Returns:
        ``(err_log, err_uni)``: two dicts keyed by weight-parameter name,
        holding whatever the two measure functions return for that tensor
        (presumably a scalar error each -- confirm against their definitions).
    """
    # Only the argument parameters are inspected; the symbol and the
    # auxiliary states returned by the checkpoint loader are not used here.
    _, arg_params, _ = mx.model.load_checkpoint(prefix_net, epoch)

    err_log = {}
    err_uni = {}
    err_diff = []
    # Sorted iteration gives a stable x-axis ordering for the plot below.
    for name in sorted(arg_params):
        # Only parameters named '*_weight' are compared; all others are skipped.
        if not name.endswith('_weight'):
            continue
        # Flatten to 1-D so the error measures see a plain vector of values.
        values = arg_params[name].asnumpy().ravel()
        err_log[name] = measure_log_quantize_error(values, quantize_bit)
        err_uni[name] = measure_uni_quantize_error(values, quantize_bit)
        err_diff.append(err_log[name] - err_uni[name])

    # x = index into the sorted weight names; a positive y means the log
    # measure reported a larger error than the uniform one for that tensor.
    plt.plot(err_diff)
    plt.xlabel('weight tensor index (sorted by name)')
    plt.ylabel('log error - uniform error')

    # Hand the results back instead of stopping in a debugger, so the
    # function can run unattended and callers can inspect the numbers.
    return err_log, err_uni
评论列表
文章目录