def _test_batch_sampler(self, **kwargs):
    """Check that DataLoader follows an explicitly supplied ``batch_sampler``.

    The sampler is a plain list of index tuples whose sizes alternate
    between 2 and 3: [(0, 1), (2, 3, 4), (5, 6), (7, 8, 9), ...].
    Each yielded batch is compared against the matching slice of
    ``self.data``. Any keyword arguments are forwarded unchanged to the
    ``DataLoader`` constructor.
    """
    index_batches = []
    for start in range(0, 100, 5):
        middle = start + 2
        index_batches += [
            tuple(range(start, middle)),
            tuple(range(middle, start + 5)),
        ]

    loader = DataLoader(self.dataset, batch_sampler=index_batches, **kwargs)
    # 20 groups of five indices, two batches per group.
    self.assertEqual(len(loader), 40)

    for batch_idx, (samples, _target) in enumerate(loader):
        # Each pair of batches spans 5 indices, so batch k starts at
        # floor(5k / 2): 0, 2, 5, 7, 10, ...
        first = batch_idx * 5 // 2
        # Even-numbered batches hold 2 items, odd-numbered ones hold 3.
        expected_size = 3 if batch_idx % 2 else 2
        self.assertEqual(len(samples), expected_size)
        self.assertEqual(samples, self.data[first:first + expected_size])
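# Comment list (web-page navigation residue; not part of the code)
# Article table of contents (web-page navigation residue; not part of the code)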