def testTensorDataset(self):
    """Check TensorDataset with a dict, a bare tensor, and a list of tensors.

    For each input form, the test checks the reported length and the item
    returned by integer indexing.
    """
    # Dict of two equal-length numpy arrays; the test expects indexing to
    # return a dict with the same keys holding the element at that position.
    columns = {key: np.arange(8) for key in ('input', 'target')}
    by_key = dataset.TensorDataset(columns)
    self.assertEqual(len(by_key), 8)
    self.assertEqual(by_key[2], {'input': 2, 'target': 2})

    # A single tensor: length and item should mirror the tensor itself.
    values = torch.randn(8)
    n_values = len(values)
    single = dataset.TensorDataset(values)
    self.assertEqual(n_values, len(single))
    self.assertEqual(values[1], single[1])

    # A one-element list of tensors: the item is expected to be subscriptable,
    # with its first entry equal to the tensor's element at that position.
    wrapped = dataset.TensorDataset([values])
    self.assertEqual(n_values, len(wrapped))
    self.assertEqual(values[1], wrapped[1][0])
评论列表
文章目录