def form_mixtures(digit1, digit2, loader, arguments):
    """Build loaders for two single-digit datasets and their additive mixtures.

    Every batch ``(ft, tar)`` drawn from ``loader`` is filtered: samples whose
    label equals ``digit1`` are collected into the first dataset, samples
    whose label equals ``digit2`` into the second. The mixtures are the
    element-wise sum of the first ``min(N1, N2)`` samples of each dataset.

    Args:
        digit1: Label value selecting the samples of the first source.
        digit2: Label value selecting the samples of the second source.
        loader: Iterable of ``(features, labels)`` batches; the selected
            features are concatenated along dim 0.
        arguments: Options object. Reads ``input_type`` ('noise' or
            'autoenc'), ``L1`` (noise width, used only for 'noise'),
            ``cuda`` and ``batch_size``.

    Returns:
        ``(loader1, loader2, loader_mix)``: three non-shuffling DataLoaders.
        ``loader1``/``loader2`` pair each source's input (Gaussian noise of
        shape ``(N, L1)`` or, for 'autoenc', the source samples themselves)
        with that source's samples; ``loader_mix`` pairs each mixture with a
        target of ones.

    Raises:
        ValueError: if ``arguments.input_type`` is not 'noise' or 'autoenc',
            or if ``loader`` yields no batches.
    """
    input_type = arguments.input_type
    # Validate before the (possibly long) pass over the loader so a bad
    # option fails immediately.
    if input_type not in ('noise', 'autoenc'):
        raise ValueError(
            "Unknown input_type {!r}; expected 'noise' or 'autoenc'".format(input_type))

    chunks1, chunks2 = [], []
    for ft, tar in loader:
        # Boolean-mask indexing always yields a (k, ...) tensor, including for
        # k == 0 and k == 1. The previous nonzero().squeeze() index became
        # 0-d for a single match, which index_select rejects on some versions.
        chunks1.append(ft[tar == digit1])
        chunks2.append(ft[tar == digit2])
    if not chunks1:
        raise ValueError('loader yielded no batches')
    dataset1 = torch.cat(chunks1, dim=0)
    dataset2 = torch.cat(chunks2, dim=0)

    n1, n2 = dataset1.size(0), dataset2.size(0)
    n_mix = min(n1, n2)

    if input_type == 'noise':
        inp1 = torch.randn(n1, arguments.L1)
        inp2 = torch.randn(n2, arguments.L1)
    else:  # 'autoenc': each source is its own input
        inp1, inp2 = dataset1, dataset2

    # Additive mixtures of the first n_mix samples of each source; both
    # sources must share the same per-sample shape for the sum to work.
    mixtures = dataset1[:n_mix] + dataset2[:n_mix]

    # NOTE(review): this bare ``TensorDataset`` accepts a ``lens`` keyword,
    # which torch.utils.data.TensorDataset does not, so it is presumably a
    # project-local class imported elsewhere. ``lens`` has n_mix entries
    # while this dataset has n1 rows -- confirm that is intended when n1 > n2.
    dataset1 = TensorDataset(data_tensor=inp1,
                             target_tensor=dataset1,
                             lens=[1]*n_mix)
    # Positional arguments work with both the legacy
    # (data_tensor, target_tensor) signature and the current (*tensors) one.
    dataset2 = data_utils.TensorDataset(inp2, dataset2)
    dataset_mix = data_utils.TensorDataset(mixtures, torch.ones(n_mix))

    kwargs = {'num_workers': 1, 'pin_memory': True} if arguments.cuda else {}
    loader1 = data_utils.DataLoader(dataset1, batch_size=arguments.batch_size,
                                    shuffle=False, **kwargs)
    loader2 = data_utils.DataLoader(dataset2, batch_size=arguments.batch_size,
                                    shuffle=False, **kwargs)
    loader_mix = data_utils.DataLoader(dataset_mix, batch_size=arguments.batch_size,
                                       shuffle=False, **kwargs)
    return loader1, loader2, loader_mix
评论列表
文章目录