def __init__(self, mode, sup_num, use_cuda=True, *args, **kwargs):
    """Load MNIST and expose the data split selected by ``mode``.

    The first training-mode instance that sees an empty cache transforms the
    training set and stores the split produced by ``split_sup_unsup_valid``
    as class attributes on ``MNISTCached``; later instances reuse those
    cached tensors instead of recomputing them.

    :param mode: one of ``"sup"``, ``"unsup"``, ``"valid"`` or ``"test"``.
        The first three select a slice of the training set; ``"test"``
        selects the test set.
    :param sup_num: number passed to ``split_sup_unsup_valid`` when the
        cache is empty. ``None`` is accepted only in ``"unsup"`` mode, in
        which case the whole transformed training set is cached as the
        unsupervised data.
    :param use_cuda: forwarded to ``fn_x_mnist`` / ``fn_y_mnist``.
    :param args: forwarded to the parent dataset constructor.
    :param kwargs: forwarded to the parent dataset constructor.
    :raises ValueError: if ``mode`` is not recognised, or if ``sup_num`` is
        ``None`` in a mode other than ``"unsup"`` while no split is cached.
    """
    # Validate before the parent constructor runs (it populates
    # train_data/test_data, presumably loading them from disk), so a bad
    # ``mode`` fails fast. An explicit raise instead of ``assert`` keeps the
    # check active under ``python -O``.
    if mode not in ["sup", "unsup", "test", "valid"]:
        raise ValueError("invalid train/test option values")
    super(MNISTCached, self).__init__(train=mode in ["sup", "unsup", "valid"], *args, **kwargs)
    self.mode = mode

    # transformations on MNIST data (normalization and one-hot conversion for labels)
    def transform(x):
        return fn_x_mnist(x, use_cuda)

    def target_transform(y):
        return fn_y_mnist(y, use_cuda)

    if mode == "test":
        # The transforms are defined just above, so they always apply.
        self.test_data = transform(self.test_data.float())
        self.test_labels = target_transform(self.test_labels)
        return

    # Training-derived modes ("sup", "unsup", "valid").
    # NOTE(review): the cache is filled once and is not keyed on ``sup_num``;
    # a later instance with a different ``sup_num`` reuses the first split.
    if MNISTCached.train_data_sup is None:
        # Only transform on a cache miss: every branch below replaces
        # self.train_data/self.train_labels with cached tensors, so
        # transforming on a hit would be discarded work.
        train_data = transform(self.train_data.float())
        train_labels = target_transform(self.train_labels)
        if sup_num is None:
            if mode != "unsup":
                raise ValueError("sup_num may only be None in 'unsup' mode")
            MNISTCached.train_data_unsup, MNISTCached.train_labels_unsup = \
                train_data, train_labels
        else:
            MNISTCached.train_data_sup, MNISTCached.train_labels_sup, \
                MNISTCached.train_data_unsup, MNISTCached.train_labels_unsup, \
                MNISTCached.data_valid, MNISTCached.labels_valid = \
                split_sup_unsup_valid(train_data, train_labels, sup_num)

    if mode == "sup":
        self.train_data, self.train_labels = MNISTCached.train_data_sup, MNISTCached.train_labels_sup
    elif mode == "unsup":
        self.train_data = MNISTCached.train_data_unsup
        # Hide the unsupervised labels from inference: one NaN per example,
        # shaped (n, 1). torch.full avoids allocating uninitialised memory
        # and then multiplying it by NaN.
        n_unsup = MNISTCached.train_labels_unsup.shape[0]
        self.train_labels = torch.full((n_unsup, 1), float("nan"))
    else:  # mode == "valid"
        self.train_data, self.train_labels = MNISTCached.data_valid, MNISTCached.labels_valid