def test_batch_single():
    """Check that batching a one-element list yields exactly one single-item batch.

    Every batch size from 1 through 9 is tried. The expected result is the same
    for each size, because a single element fits in one batch whenever
    ``batch_size >= 1``.
    """
    for batch_size in range(1, 10):
        # Materialise the result so the comparison works whether main.batch
        # returns a list or a lazy iterator.
        batched = tuple(main.batch([1], batch_size=batch_size))
        assert batched == ((1,),), f"batch_size={batch_size}: got {batched!r}"