def min_max_quantize(input, bits):
    """Quantize ``input`` onto evenly spaced levels between its min and max.

    For ``bits > 1`` the values are rescaled to ``[0, 1]`` using the tensor's
    own minimum and maximum, snapped to the nearest of ``2**bits`` levels
    (round-half-up via ``floor(x + 0.5)``), and mapped back to the original
    range.

    Args:
        input: Tensor to quantize. It is not modified; a new tensor is
            returned.
        bits: Number of quantization bits; must be >= 1.

    Returns:
        A tensor of the same shape as ``input`` holding the quantized values.
        For ``bits == 1`` the result is ``torch.sign(input) - 1``, i.e. values
        drawn from ``{-2, -1, 0}``.
        NOTE(review): that 1-bit coding is kept as-is for compatibility;
        confirm it is what callers expect.

    Raises:
        AssertionError: If ``bits < 1``.
    """
    assert bits >= 1, bits
    if bits == 1:
        return torch.sign(input) - 1

    # Nothing to quantize; also avoids min()/max() raising on empty input.
    if input.numel() == 0:
        return input.clone()

    # float() accepts any single-element result of the reduction, including
    # 0-dim tensors, which the former ``.numpy()[0]`` indexing could not.
    min_val = float(input.min())
    max_val = float(input.max())

    value_range = max_val - min_val
    if value_range == 0:
        # Every element equals min_val; rescaling would divide by zero and
        # produce NaNs, while the input is already exactly representable.
        return input.clone()

    input_rescale = (input - min_val) / value_range
    n = math.pow(2.0, bits) - 1  # number of steps between the two extremes
    v = torch.floor(input_rescale * n + 0.5) / n
    return v * value_range + min_val
评论列表
文章目录