def test_sequence_wise_torch_data_loader():
    """Yield-style (nose generator) test: feed sequence-wise datasets to a
    torch ``DataLoader``.

    Checks that:
    - with variable-length utterances (``padded=False``), batch_size=1 works
      but batch_size=2 raises ``RuntimeError`` (torch cannot collate tensors
      of differing time lengths), and
    - with a padded dataset of shape (N, T^max, D), any batch size works.

    Yields ``(callable, X, Y, batch_size)`` tuples for the test runner.
    """
    import torch
    from torch.utils import data as data_utils

    X, Y = _get_small_datasets(padded=False)

    class TorchDataset(data_utils.Dataset):
        # Minimal Dataset wrapper: pairs per-utterance numpy features
        # (X[idx], Y[idx]) and converts them to torch tensors on access.
        def __init__(self, X, Y):
            self.X = X
            self.Y = Y

        def __getitem__(self, idx):
            return torch.from_numpy(self.X[idx]), torch.from_numpy(self.Y[idx])

        def __len__(self):
            return len(self.X)

    def __test(X, Y, batch_size):
        dataset = TorchDataset(X, Y)
        loader = data_utils.DataLoader(
            dataset, batch_size=batch_size, num_workers=1, shuffle=True)
        for idx, (x, y) in enumerate(loader):
            assert len(x.shape) == len(y.shape)
            # Batched sequence data is rank 3: (batch, time, feature_dim)
            assert len(x.shape) == 3
            print(idx, x.shape, y.shape)

    # Test with batch_size = 1
    yield __test, X, Y, 1
    # Since we have variable length frames, batch size larger than 1 causes
    # runtime error (default collate cannot stack unequal-length tensors).
    yield raises(RuntimeError)(__test), X, Y, 2

    # For padded dataset, which can be represented by (N, T^max, D),
    # batch size can be any number.
    X, Y = _get_small_datasets(padded=True)
    yield __test, X, Y, 1
    yield __test, X, Y, 2