test_datasets.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:nnmnkwii 作者: r9y9 项目源码 文件源码
def test_sequence_wise_torch_data_loader():
    """Generate DataLoader batching checks for sequence-wise datasets.

    This is a generator-style test: each yielded tuple has the form
    ``(callable, X, Y, batch_size)`` and is meant to be invoked by the
    test runner as ``callable(X, Y, batch_size)``.

    Four cases are produced, in order:

    1. un-padded data, batch size 1 -- expected to pass;
    2. un-padded data, batch size 2 -- expected to raise ``RuntimeError``;
    3. padded data, batch size 1 -- expected to pass;
    4. padded data, batch size 2 -- expected to pass.
    """
    import torch
    from torch.utils import data as data_utils

    class _PairedDataset(data_utils.Dataset):
        """Expose ``(X[i], Y[i])`` as a pair of tensors built from numpy."""

        def __init__(self, inputs, targets):
            self.X = inputs
            self.Y = targets

        def __len__(self):
            return len(self.X)

        def __getitem__(self, idx):
            x_tensor = torch.from_numpy(self.X[idx])
            y_tensor = torch.from_numpy(self.Y[idx])
            return x_tensor, y_tensor

    def _check_batches(X, Y, batch_size):
        """Iterate a shuffled loader over (X, Y) and verify every batch is 3-D."""
        loader = data_utils.DataLoader(
            _PairedDataset(X, Y),
            batch_size=batch_size,
            num_workers=1,
            shuffle=True,
        )
        for batch_no, batch in enumerate(loader):
            x, y = batch
            x_rank = len(x.shape)
            # Inputs and targets must agree in rank, and that rank must be 3.
            assert x_rank == len(y.shape)
            assert x_rank == 3
            print(batch_no, x.shape, y.shape)

    # Un-padded data: one sequence per batch works.
    X, Y = _get_small_datasets(padded=False)
    yield _check_batches, X, Y, 1

    # Frames have variable length here, so batching more than one sequence
    # together causes a runtime error.
    expect_runtime_error = raises(RuntimeError)(_check_batches)
    yield expect_runtime_error, X, Y, 2

    # A padded dataset, which can be represented by (N, T^max, D), accepts
    # any batch size.
    X, Y = _get_small_datasets(padded=True)
    for size in (1, 2):
        yield _check_batches, X, Y, size
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号