serving_test.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:treecat 作者: posterior 项目源码 文件源码
def validate_gof(N, V, C, M, server, conditional):
    """Assert that rows sampled from ``server`` match its own probabilities.

    Draws ``1000 * C**V`` rows from ``server.sample``, tallies the distinct
    rows, turns the corresponding ``server.logprob`` values into normalized
    probabilities, and feeds both into ``multinomial_goodness_of_fit``.

    Args:
        N: Not used by this function.
        V: Length of the all-ones vector passed to ``server.sample``.
        C: Base of the expected number of distinct rows (``C**V``).
        M: Not used by this function.
        server: Object providing ``sample``, ``logprob`` and
            ``make_zero_row``.
        conditional: When true, the conditioning row is the first row of a
            single draw from ``server.sample``; otherwise it is
            ``server.make_zero_row()``.

    Raises:
        AssertionError: If the number of distinct sampled rows differs from
            ``C**V``, or the goodness-of-fit value is not above ``1e-2``.
    """
    # Draw the samples, conditioned on either a sampled row or a zero row.
    num_outcomes = C**V
    num_samples = 1000 * num_outcomes
    all_ones = np.ones(V, dtype=np.int8)
    cond_data = (server.sample(1, all_ones)[0, :]
                 if conditional else server.make_zero_row())
    samples = server.sample(num_samples, all_ones, cond_data)
    scored_rows = samples + cond_data[np.newaxis, :]
    logprobs = server.logprob(scored_rows)

    # Tally each distinct row; its probability is recorded on first sight.
    tally = {}
    prob_of = {}
    for row, logprob in zip(samples, logprobs):
        outcome = tuple(row)
        seen = tally.get(outcome, 0)
        if seen == 0:
            prob_of[outcome] = np.exp(logprob)
        tally[outcome] = seen + 1
    assert len(tally) == num_outcomes

    # Check accuracy using Pearson's chi-squared test, with outcomes ordered
    # from most to least probable.
    ordered = sorted(tally, key=lambda outcome: -prob_of[outcome])
    count_arr = np.array([tally[outcome] for outcome in ordered],
                         dtype=np.int32)
    prob_arr = np.array([prob_of[outcome] for outcome in ordered])
    prob_arr /= prob_arr.sum()

    # Drop the low-probability tail (expected count of 20 or fewer) to avoid
    # low precision, but always keep at least the first 8 outcomes.
    enough = prob_arr * num_samples > 20
    truncated = not enough.all()
    if truncated:
        # argmin on a boolean array gives the index of the first False.
        cutoff = max(8, enough.argmin())
        prob_arr = prob_arr[:cutoff]
        count_arr = count_arr[:cutoff]

    gof = multinomial_goodness_of_fit(
        prob_arr, count_arr, num_samples, plot=True, truncated=truncated)
    assert gof > 1e-2
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号