def form_torch_mixture_dataset(MSabs, MSphase,
                               SPCS1abs, SPCS2abs,
                               wavfls1, wavfls2,
                               lens1, lens2,
                               arguments):
    """Build a non-shuffling DataLoader over a MixtureDataset.

    Each of ``MSabs``, ``MSphase``, ``SPCS1abs``, ``SPCS2abs``, ``wavfls1``
    and ``wavfls2`` is materialised with ``np.array`` and then wrapped with
    ``torch.from_numpy``. Because ``np.array`` copies by default, the tensors
    do not alias the caller's original buffers.

    NOTE(review): the names suggest mixture magnitude/phase spectra,
    per-source magnitude spectra and per-source waveforms. That is inferred
    from naming only, so confirm against MixtureDataset.
    NOTE(review): inputs presumably need to be rectangular (non-ragged)
    numeric data. Ragged or object-typed input would make the
    np.array / torch.from_numpy conversion fail. Confirm against callers.

    Args:
        MSabs, MSphase, SPCS1abs, SPCS2abs, wavfls1, wavfls2: array-likes
            converted to tensors, in this order.
        lens1, lens2: forwarded to MixtureDataset unchanged (not converted).
        arguments: object exposing ``cuda`` and ``batch_size``. When
            ``cuda`` is truthy the loader uses one worker process and pinned
            memory; otherwise DataLoader defaults apply.

    Returns:
        A ``data_utils.DataLoader`` with ``shuffle=False``, so batches follow
        dataset order.
    """
    raw_inputs = (MSabs, MSphase, SPCS1abs, SPCS2abs, wavfls1, wavfls2)
    (mix_mag, mix_phase,
     src1_mag, src2_mag,
     src1_wav, src2_wav) = (torch.from_numpy(np.array(item))
                            for item in raw_inputs)

    dataset = MixtureDataset(mix_mag, mix_phase, src1_mag, src2_mag,
                             src1_wav, src2_wav, lens1, lens2)

    # Worker/pinning options are only requested when running on CUDA.
    if arguments.cuda:
        loader_options = {'num_workers': 1, 'pin_memory': True}
    else:
        loader_options = {}

    return data_utils.DataLoader(dataset,
                                 batch_size=arguments.batch_size,
                                 shuffle=False,
                                 **loader_options)
评论列表
文章目录