svhn_dataset.py source code

Project: pathnet-pytorch, author: kimhc6028

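```python
import numpy as np
import torch
import torch.utils.data as data_utils

import common  # pathnet-pytorch's own helper module; provides normalize()


def convert2tensor(self, dataset, batch_size, limit):
    """Turn an SVHN .mat dict into a shuffled DataLoader of (image, label) pairs."""
    # SVHN stores images as (height, width, channels, num_samples), so the
    # sample limit must be applied to the last axis, not the first.
    b_data = dataset['X'][..., :limit]
    print("normalizing images...")
    b_data = common.normalize(b_data)
    print("done")

    # Labels come as an (N, 1) array; flatten to shape (N,).
    target = dataset['y'].reshape(-1)[:limit]
    # SVHN labels run from 1 to 10 (10 stands for the digit 0); shift them
    # to 0..9 so they can index the network's output classes.
    target = target.astype(np.int64) - 1

    # Move the sample axis to the front: (N, height, width, channels).
    data = np.ascontiguousarray(np.transpose(b_data, (3, 0, 1, 2)))
    tensor_data = torch.from_numpy(data).float()
    # int64 class indices, as expected by losses such as nn.CrossEntropyLoss.
    tensor_target = torch.from_numpy(target)

    dataset_t = data_utils.TensorDataset(tensor_data, tensor_target)
    loader = data_utils.DataLoader(dataset_t, batch_size=batch_size, shuffle=True)
    return loader
```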
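The array layout this function depends on can be checked on synthetic data. The sketch below is a minimal NumPy-only illustration, assuming SVHN's `.mat` layout (`X` shaped `(32, 32, 3, N)`, `y` shaped `(N, 1)` with labels 1 to 10, where 10 means the digit 0). The random arrays and the helper `to_samples_first` are hypothetical stand-ins, and `common.normalize` is skipped. It shows that slicing `X[:limit]` trims image rows rather than samples, that a single `np.transpose` gives the same `(N, H, W, C)` array as stacking `X[:, :, :, i]` one sample at a time, and how subtracting 1 maps label 10 to class 9.

```python
import numpy as np


def to_samples_first(x):
    """Hypothetical helper: reorder an SVHN-style (H, W, C, N) array to (N, H, W, C)."""
    return np.ascontiguousarray(np.transpose(x, (3, 0, 1, 2)))


# Synthetic stand-ins for dataset['X'] and dataset['y'] (not real SVHN data).
rng = np.random.default_rng(0)
n_samples, limit = 50, 10
X = rng.integers(0, 256, size=(32, 32, 3, n_samples)).astype(np.uint8)
y = rng.integers(1, 11, size=(n_samples, 1)).astype(np.uint8)

# Slicing the first axis trims image rows, not samples.
print(X[:limit].shape)        # (10, 32, 3, 50)
print(X[..., :limit].shape)   # (32, 32, 3, 10)

# Stacking X[:, :, :, i] one sample at a time equals a single transpose.
subset = X[..., :limit]
looped = np.asarray([subset[:, :, :, i] for i in range(limit)])
fast = to_samples_first(subset)
print(fast.shape, np.array_equal(looped, fast))   # (10, 32, 32, 3) True

# Label shift: 1..9 -> 0..8 and 10 (the digit 0) -> 9, stored as int64.
labels = y.reshape(-1)[:limit].astype(np.int64) - 1
print(np.array([1, 9, 10]) - 1)   # [0 8 9]
```

Casting to `int64` before subtracting keeps the labels in the integer type PyTorch classification losses expect once they are wrapped with `torch.from_numpy`.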