def validate_gof(N, V, C, M, server, conditional):
    """Check that ``server.sample`` agrees with ``server.logprob``.

    Draws ``1000 * C**V`` samples from ``server``, tallies how often each
    distinct sample occurs, and compares those empirical counts against the
    probabilities reported by ``server.logprob`` via
    ``multinomial_goodness_of_fit`` (Pearson's chi-squared test).

    Args:
        N: Not used by this function; kept for interface compatibility.
        V: Length of the all-ones mask passed to ``server.sample``.
        C: Base of the expected outcome count ``C**V`` (presumably the number
            of values per feature -- confirm against callers).
        M: Not used by this function; kept for interface compatibility.
        server: Object providing ``sample``, ``logprob`` and
            ``make_zero_row``.
        conditional: If true, condition on one row drawn from ``server``;
            otherwise condition on ``server.make_zero_row()``.

    Raises:
        AssertionError: If the number of distinct samples differs from
            ``C**V``, or if the goodness-of-fit result is not above 1e-2.
    """
    from collections import Counter

    # Generate samples.
    expected = C**V
    num_samples = 1000 * expected
    ones = np.ones(V, dtype=np.int8)
    if conditional:
        cond_data = server.sample(1, ones)[0, :]
    else:
        cond_data = server.make_zero_row()
    samples = server.sample(num_samples, ones, cond_data)
    # Add the conditioning row back onto every sample before scoring it.
    logprobs = server.logprob(samples + cond_data[np.newaxis, :])

    # Tally distinct samples; each outcome's probability is taken from its
    # first occurrence.
    count_by_key = Counter()
    prob_by_key = {}
    for sample, logprob in zip(samples, logprobs):
        key = tuple(sample)
        count_by_key[key] += 1
        if key not in prob_by_key:
            prob_by_key[key] = np.exp(logprob)
    assert len(count_by_key) == expected, (
        'observed {} distinct samples, expected {}'.format(
            len(count_by_key), expected))

    # Check accuracy using Pearson's chi-squared test.
    # Order outcomes from most to least probable so that truncation below
    # discards only the rarest ones.
    keys = sorted(count_by_key, key=lambda k: -prob_by_key[k])
    counts = np.array([count_by_key[k] for k in keys], dtype=np.int32)
    probs = np.array([prob_by_key[k] for k in keys])
    probs /= probs.sum()

    # Truncate to avoid low-precision: keep the leading outcomes whose
    # expected count exceeds 20, but keep at least 8 (or all, if fewer).
    truncated = False
    valid = probs * num_samples > 20
    if not valid.all():
        # probs is sorted descending, so `valid` is a run of True followed by
        # a run of False; argmin returns the index of the first False.
        T = max(8, valid.argmin())  # Avoid truncating too much
        probs = probs[:T]
        counts = counts[:T]
        truncated = True

    gof = multinomial_goodness_of_fit(
        probs, counts, num_samples, plot=True, truncated=truncated)
    assert 1e-2 < gof