test_datasets.py source code

python

Project: nnmnkwii    Author: r9y9
def test_frame_wise_torch_data_loader():
    import numpy as np
    import torch
    from torch.utils import data as data_utils
    from nnmnkwii.datasets import MemoryCacheFramewiseDataset

    # `_get_small_datasets` is a helper defined elsewhere in this test module.
    X, Y = _get_small_datasets(padded=False)

    # Since torch's Dataset (and Chainer's, and possibly others) assumes the
    # dataset has a fixed length, i.e., implements the `__len__` method, we
    # need to know the number of frames for each utterance.
    # The sum of the frame counts is the dataset size for frame-wise iteration.
    lengths = np.array([len(x) for x in X], dtype=int)

    # For the above reason, we need to explicitly give the number of frames.
    X = MemoryCacheFramewiseDataset(X, lengths, cache_size=len(X))
    Y = MemoryCacheFramewiseDataset(Y, lengths, cache_size=len(Y))

    class TorchDataset(data_utils.Dataset):
        def __init__(self, X, Y):
            self.X = X
            self.Y = Y

        def __getitem__(self, idx):
            return torch.from_numpy(self.X[idx]), torch.from_numpy(self.Y[idx])

        def __len__(self):
            return len(self.X)

    def __test(X, Y, batch_size):
        dataset = TorchDataset(X, Y)
        loader = data_utils.DataLoader(
            dataset, batch_size=batch_size, num_workers=1, shuffle=True)
        for idx, (x, y) in enumerate(loader):
            assert len(x.shape) == 2
            assert len(y.shape) == 2

    # nose-style generator test: run the check with two different batch sizes
    yield __test, X, Y, 128
    yield __test, X, Y, 256
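
For reference, the key idea behind a frame-wise dataset such as MemoryCacheFramewiseDataset is that a flat frame index has to be mapped back to an (utterance index, frame offset) pair using the cumulative per-utterance lengths, which is why the test passes `lengths` explicitly. The sketch below illustrates that mapping with plain NumPy; `frame_to_utterance_index` is a hypothetical helper written for this explanation, not part of nnmnkwii.

import numpy as np

def frame_to_utterance_index(frame_idx, lengths):
    # Hypothetical helper for illustration only (not part of nnmnkwii).
    # Cumulative frame counts mark utterance boundaries in the flat index.
    boundaries = np.cumsum(lengths)
    # The owning utterance is the first boundary strictly greater than frame_idx.
    utt_idx = int(np.searchsorted(boundaries, frame_idx, side="right"))
    # Offset of the frame inside that utterance.
    offset = int(frame_idx - (boundaries[utt_idx - 1] if utt_idx > 0 else 0))
    return utt_idx, offset

# Three utterances with 3, 5, and 2 frames => 10 frames in total.
lengths = np.array([3, 5, 2])
assert frame_to_utterance_index(0, lengths) == (0, 0)  # first frame of utterance 0
assert frame_to_utterance_index(3, lengths) == (1, 0)  # first frame of utterance 1
assert frame_to_utterance_index(9, lengths) == (2, 1)  # last frame of utterance 2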