cifar_dataset.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:pathnet-pytorch 作者: kimhc6028 项目源码 文件源码
def convert2tensor(self, dataset, batch_size, limit):
        data = dataset['data']
        data = data[:limit]
        print("normalizing images...")
        data = common.normalize(data)
        print("done")
        target = dataset['labels']
        target = target[:limit]
        target = np.asarray(target)

        tensor_data = torch.from_numpy(data)
        tensor_data = tensor_data.float()
        tensor_target = torch.from_numpy(target)

        loader = data_utils.TensorDataset(tensor_data, tensor_target)
        loader_dataset = data_utils.DataLoader(loader, batch_size=batch_size, shuffle = True)
        return loader_dataset
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号