def quantize_from_probs2(probs, resolution):
    """Quantize multiple non-normalized probs to given resolution.

    Every row is first normalized to sum to one. Then ``resolution`` unit
    counts are handed out greedily, one per iteration: each row gives a
    count to its currently largest remaining probability, and that
    probability is reduced by ``1 / resolution``.

    Args:
        probs: An [N, M]-shaped numpy array of non-normalized probabilities.
            The caller's array is not modified (normalization produces a
            new array). Rows are assumed to have a nonzero sum; a zero-sum
            row normalizes to NaN and its quantization is not meaningful.
        resolution: Integer number of counts to distribute within each row.

    Returns:
        An [N, M]-shaped integer array of quantized probabilities such that
        np.all(result.sum(axis=1) == resolution). The dtype is np.int8 when
        ``resolution`` fits in int8 and np.int64 otherwise.
    """
    assert probs.ndim == 2
    num_rows = probs.shape[0]
    probs = probs / probs.sum(axis=1, keepdims=True)
    # A single cell can receive up to `resolution` counts, so int8 is only
    # safe up to 127; beyond that the counter would silently wrap around.
    if resolution <= np.iinfo(np.int8).max:
        count_dtype = np.int8
    else:
        count_dtype = np.int64
    result = np.zeros(probs.shape, count_dtype)
    rows = np.arange(num_rows, dtype=np.int32)
    for _ in range(resolution):
        # Ties are resolved by argmax, i.e. the lowest column index wins.
        winners = probs.argmax(axis=1)
        result[rows, winners] += 1
        # Computed inside the loop so that resolution == 0 never divides.
        probs[rows, winners] -= 1.0 / resolution
    return result
评论列表
文章目录