def linear_quantize(input, sf, bits):
    """Quantize ``input`` onto a signed fixed-point grid with step ``2 ** -sf``.

    Args:
        input: Value passed to ``torch.sign`` / ``torch.floor`` /
            ``torch.clamp``; presumably a floating-point ``torch.Tensor``
            (confirm against callers).
        sf: Exponent of the quantization step; the step is ``2 ** -sf``.
        bits: Bit width of the signed integer grid; must be at least 1.

    Returns:
        For ``bits == 1``: ``torch.sign(input) - 1``.
        Otherwise: ``input`` divided by the step, rounded half-up, clamped
        to ``[-2 ** (bits - 1), 2 ** (bits - 1) - 1]`` and multiplied by the
        step again.

    Raises:
        AssertionError: If ``bits`` is smaller than 1.
    """
    assert bits >= 1, bits
    if bits == 1:
        # NOTE(review): sign() yields -1/0/1, so this branch returns -2/-1/0
        # and ignores ``sf``. Kept exactly as in the original; confirm that
        # this is the intended 1-bit encoding.
        return torch.sign(input) - 1

    delta = math.pow(2.0, -sf)  # quantization step
    bound = math.pow(2.0, bits - 1)
    min_val = -bound
    max_val = bound - 1

    # floor(x + 0.5) rounds to the nearest integer, with ties going upward.
    rounded = torch.floor(input / delta + 0.5)
    # Saturate to the representable range, then scale back to input units.
    return torch.clamp(rounded, min_val, max_val) * delta
# Comment list
# Table of contents