test_datasets.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:tnt 作者: pytorch 项目源码 文件源码
def testTensorDataset(self):
        # dict input
        data = {
            # 'input': torch.arange(0,8),
            'input': np.arange(0, 8),
            'target': np.arange(0, 8),
        }
        d = dataset.TensorDataset(data)
        self.assertEqual(len(d), 8)
        self.assertEqual(d[2], {'input': 2, 'target': 2})

        # tensor input
        a = torch.randn(8)
        d = dataset.TensorDataset(a)
        self.assertEqual(len(a), len(d))
        self.assertEqual(a[1], d[1])

        # list of tensors input
        d = dataset.TensorDataset([a])
        self.assertEqual(len(a), len(d))
        self.assertEqual(a[1], d[1][0])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号