def testBatchDataset(self):
    """Check that BatchDataset groups consecutive samples into batches.

    The integers 0..15 are wrapped in a ListDataset whose samples are dicts
    with an 'input' key. They are batched in groups of 8, and the first batch
    must contain 8 entries whose last value is 7.
    """
    batch_size = 8
    # Guard for torch builds without ``torch.arange``.
    # ``torch.range`` includes its end point, so (0, 15) produces the same
    # 0..15 sequence as ``arange(0, 16)``.
    values = (
        torch.arange(0, 16) if hasattr(torch, "arange") else torch.range(0, 15)
    ).long()
    samples = dataset.ListDataset(values, lambda x: {'input': x})
    batches = dataset.BatchDataset(samples, batch_size)
    first_batch = batches[0]['input']
    self.assertEqual(len(first_batch), batch_size)
    # The first batch covers samples 0..batch_size-1, so it ends at batch_size-1.
    self.assertEqual(first_batch[-1], batch_size - 1)
# def testTransformDataset(self):
# d = dataset.TransformDataset(dataset.TensorDataset()
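# Stray web-page residue (translated): "Comment list"
# Stray web-page residue (translated): "Table of contents"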