def test_sequential_batch(self):
loader = DataLoader(self.dataset, batch_size=2, shuffle=False)
batch_size = loader.batch_size
for i, sample in enumerate(loader):
idx = i * batch_size
self.assertEqual(set(sample.keys()), {'a_tensor', 'another_dict'})
self.assertEqual(set(sample['another_dict'].keys()), {'a_number'})
t = sample['a_tensor']
self.assertEqual(t.size(), torch.Size([batch_size, 4, 2]))
self.assertTrue((t[0] == idx).all())
self.assertTrue((t[1] == idx + 1).all())
n = sample['another_dict']['a_number']
self.assertEqual(n.size(), torch.Size([batch_size]))
self.assertEqual(n[0], idx)
self.assertEqual(n[1], idx + 1)
评论列表
文章目录