def get_loader(chunk_list, batch_size=256, shuffle=True, num_workers=10,
               pin_memory=False):
    """Load HDF5 chunk files into memory and wrap them in a DataLoader.

    Every file in ``chunk_list`` is opened read-only with h5py and its
    ``'data'`` and ``'label'`` datasets are read fully into memory. The
    per-file arrays are concatenated along axis 0, converted to float32
    tensors and served through a ``TensorDataset``.

    Args:
        chunk_list: Iterable of HDF5 file paths, each holding ``'data'``
            and ``'label'`` datasets.
        batch_size: Samples per batch (previously hard-coded to 256).
        shuffle: Reshuffle samples every epoch (previously ``True``).
        num_workers: DataLoader worker processes (previously 10).
        pin_memory: Forwarded to ``DataLoader`` (previously ``False``).

    Returns:
        A ``DataLoader`` over ``TensorDataset(data, label)`` where both
        tensors are float32.

    Raises:
        ValueError: If ``chunk_list`` yields no files, or if a file's
            ``'data'`` and ``'label'`` datasets differ in length along
            axis 0.
    """
    data = []
    label = []
    for f in chunk_list:
        print('Loading data from %s' % f)
        with h5py.File(f, 'r') as hf:
            chunk_data = np.asarray(hf['data'])
            chunk_label = np.asarray(hf['label'])
        # Validate per file: length mismatches in different files could
        # cancel out in the totals (passing TensorDataset's size check) and
        # silently pair every later sample with the wrong label.
        if chunk_data.shape[:1] != chunk_label.shape[:1]:
            raise ValueError(
                '%s: data shape %s and label shape %s differ along axis 0'
                % (f, chunk_data.shape, chunk_label.shape))
        data.append(chunk_data)
        label.append(chunk_label)
    if not data:
        # Checked after the loop so generators work too; np.concatenate
        # would otherwise fail with a much less helpful message.
        raise ValueError('chunk_list is empty; no HDF5 files to load')
    # astype(..., copy=False) only copies when a conversion is needed
    # (e.g. float64 or non-native byte order, which torch.from_numpy
    # rejects); from_numpy then shares the buffer instead of copying again.
    data = torch.from_numpy(
        np.concatenate(data, axis=0).astype(np.float32, copy=False))
    label = torch.from_numpy(
        np.concatenate(label, axis=0).astype(np.float32, copy=False))
    print('Total %d frames loaded' % data.size(0))
    dset_train = TensorDataset(data, label)
    loader_train = DataLoader(dset_train, batch_size=batch_size,
                              shuffle=shuffle, num_workers=num_workers,
                              pin_memory=pin_memory)
    return loader_train
评论列表
文章目录