def test_ensemble_logprob_shape(ensemble):
    """Check that ``EnsembleServer.logprob`` yields one finite float32 score per row.

    Wraps ``ensemble`` (presumably a pytest fixture supplied by the test
    module/conftest — confirm) in an ``EnsembleServer``, scores
    ``TINY_TABLE.data``, and asserts that the result:

    * has dtype ``np.float32``,
    * is 1-D with exactly ``TINY_TABLE.num_rows`` entries,
    * contains no NaN or +/-inf values.
    """
    tiny = TINY_TABLE
    srv = EnsembleServer(ensemble)
    scores = srv.logprob(tiny.data)
    expected_rows = tiny.num_rows

    assert scores.dtype == np.float32
    # A 1-tuple: one score per input row, no trailing dimensions.
    assert scores.shape == (expected_rows,)
    assert np.all(np.isfinite(scores))
评论列表
文章目录