def test_quantize_from_probs2(size, resolution):
    """Check quantize_from_probs2 on a single random row of probabilities.

    Verifies the shape, dtype and row sum of the result, then compares its
    L1 distance to the scaled target against every candidate enumerated by
    ``itertools.combinations(range(size), resolution)``.
    """
    set_random_seed(make_seed(size, resolution))
    probs = np.exp(np.random.random(size)).astype(np.float32)
    batch = probs.reshape(1, size)

    result = quantize_from_probs2(batch, resolution)
    assert result.shape == batch.shape
    assert result.dtype == np.int8
    assert np.all(result.sum(axis=1) == resolution)

    # The quantized row must be at least as close (in L1 distance) to the
    # real-valued target as every enumerated alternative.
    # NOTE(review): combinations() picks each index at most once, so only
    # 0/1-valued alternatives are compared -- confirm this is intended.
    flat = result.reshape(size)
    target = resolution * probs / probs.sum()
    best = np.abs(flat - target).sum()
    for chosen in itertools.combinations(range(size), resolution):
        candidate = np.zeros(size, np.int8)
        # Indices within one combination are distinct, so a single indexed
        # assignment matches incrementing each chosen slot once.
        candidate[np.array(chosen, dtype=np.intp)] = 1
        assert candidate.sum() == resolution
        assert best <= np.abs(candidate - target).sum()
# Comment list
# Table of contents