def test__batch_generator(self, dataset_provider, mocker):
mocker.patch.object(dataset_provider, '_preprocess_batch',
lambda x, _: x)
datum_list = range(10)
generator = dataset_provider._batch_generator(datum_list)
results = [next(generator) for _ in range(4)]
assert [len(x) for x in results] == [4, 4, 2, 4]
assert sorted(sum(results[:-1], [])) == datum_list
datum_list = range(12)
generator = dataset_provider._batch_generator(datum_list)
assert isinstance(generator, GeneratorType)
results = list(islice(generator, 4))
assert [len(x) for x in results] == [4, 4, 4, 4]
assert sorted(sum(results[:-1], [])) == datum_list
dataset_providers_test.py source file
python
Views 24
Favorites 0
Likes 0
Comments 0
Comment list
Table of contents