def log_linear_quantize(input, sf, bits):
    """Quantize ``input`` in the log domain while preserving its signs.

    The magnitudes are moved to log space, passed through
    ``linear_quantize`` together with ``sf`` and ``bits``, then mapped back
    with ``exp`` and multiplied by the original signs.

    Args:
        input: Tensor to quantize.
        sf: Forwarded unchanged to ``linear_quantize`` (presumably its
            scaling-factor argument -- confirm against that helper).
        bits: Bit width; must be >= 1.

    Returns:
        A tensor. For ``bits == 1`` it is ``torch.sign(input)``; otherwise
        it is the sign-restored exponential of the quantized log-magnitudes.

    Raises:
        AssertionError: If ``bits`` is smaller than 1.
    """
    assert bits >= 1, bits

    signs = torch.sign(input)
    if bits == 1:
        # Only the sign survives at one bit. Return a bare tensor, matching
        # the general path below, rather than a (sign, 0.0, 0.0) tuple that
        # callers expecting a tensor cannot use.
        return signs

    # The tiny offset keeps log() finite for zero-valued entries; their sign
    # is 0, so they still come out as 0 after the final multiplication.
    log_magnitude = torch.log(torch.abs(input) + 1e-20)
    quantized = linear_quantize(log_magnitude, sf, bits)
    return torch.exp(quantized) * signs
评论列表
文章目录