def test_batch_predict():
    """Test the batch prediction feed dict generator.

    Passes a feed dict holding a 100-element array to
    ``ab.batch_prediction`` with ``batch_size=10`` and checks that the result
    is a generator of ``(ind, feed_dict)`` pairs in which ``feed_dict['X']``
    has 10 elements equal to the input indexed by ``ind``.
    """
    X = np.arange(100)
    fd = {'X': X}
    data = ab.batch_prediction(fd, batch_size=10)

    # Make sure this is a generator
    assert isinstance(data, GeneratorType)

    # Make sure we get a dict back of a length we expect with correct indices
    n_batches = 0
    for ind, d in data:
        n_batches += 1
        assert isinstance(d, dict)
        assert 'X' in d
        assert len(d['X']) == 10
        # array_equal also rejects a shape mismatch, which an element-wise
        # ``==`` could silently hide through broadcasting.
        assert np.array_equal(X[ind], d['X'])

    # An empty generator would skip the loop above and let the test pass
    # without checking anything, so require at least one batch.
    assert n_batches > 0
# (Web-page residue, not part of the test: "comment list" / "article contents".)