mnist_cached.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:pyro 作者: uber 项目源码 文件源码
def __init__(self, mode, sup_num, use_cuda=True, *args, **kwargs):
        """
        Load MNIST through the parent constructor, transform it once and
        expose the split selected by ``mode`` on this instance.

        The transformed training splits are stored on the ``MNISTCached``
        class itself, so they are shared by every instance.

        :param mode: one of "sup", "unsup", "valid" or "test"; the first
            three make the parent constructor load with ``train=True``
        :param sup_num: passed to ``split_sup_unsup_valid`` when the class-level
            cache is filled; if it is ``None`` at that point, ``mode`` must be
            "unsup" and the whole training set is cached as unsupervised data
        :param use_cuda: forwarded to ``fn_x_mnist`` and ``fn_y_mnist``
        :param args: extra positional arguments for the parent constructor
        :param kwargs: extra keyword arguments for the parent constructor
        """
        training_modes = ["sup", "unsup", "valid"]
        super(MNISTCached, self).__init__(train=mode in training_modes, *args, **kwargs)

        self.mode = mode

        assert mode in ["sup", "unsup", "test", "valid"], "invalid train/test option values"

        if mode not in training_modes:
            # test split: apply the MNIST transformations and stop here,
            # nothing is cached at class level for test data
            self.test_data = fn_x_mnist(self.test_data.float(), use_cuda)
            self.test_labels = fn_y_mnist(self.test_labels, use_cuda)
            return

        # apply the MNIST transformations (normalization and one-hot
        # conversion for labels) to the full training set
        self.train_data = fn_x_mnist(self.train_data.float(), use_cuda)
        self.train_labels = fn_y_mnist(self.train_labels, use_cuda)

        cache = MNISTCached
        if cache.train_data_sup is None:
            # no supervised split cached yet: populate the class-level cache
            if sup_num is not None:
                splits = split_sup_unsup_valid(self.train_data, self.train_labels, sup_num)
                (cache.train_data_sup, cache.train_labels_sup,
                 cache.train_data_unsup, cache.train_labels_unsup,
                 cache.data_valid, cache.labels_valid) = splits
            else:
                # without a supervised count only the unsupervised view exists
                assert mode == "unsup"
                cache.train_data_unsup = self.train_data
                cache.train_labels_unsup = self.train_labels

        # point this instance at the cached split that matches its mode
        if mode == "sup":
            self.train_data = cache.train_data_sup
            self.train_labels = cache.train_labels_sup
        elif mode == "unsup":
            self.train_data = cache.train_data_unsup

            # replace the labels by a NaN column so that the unsupervised
            # labels are not available to inference
            num_unsup = cache.train_labels_unsup.shape[0]
            self.train_labels = torch.Tensor(num_unsup).view(-1, 1) * np.nan
        else:
            self.train_data = cache.data_valid
            self.train_labels = cache.labels_valid
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号