def test():
    """Fetch one batch from the dataset and dump it for visual inspection.

    Builds a ``ListDataset`` over the hard-coded image root and list file,
    wraps it in a ``DataLoader`` (batch size 8, no shuffling, one worker),
    and takes the first batch only.  The sizes of the image, location-target
    and class-target tensors are printed, and the batch images are written
    to ``a.jpg`` as a grid with one image per row.  If the loader yields
    no batches, nothing is printed or saved.
    """
    import torchvision

    # Preprocessing: convert to tensor, then per-channel normalization.
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(
        (0.485, 0.456, 0.406),
        (0.229, 0.224, 0.225),
    )
    preprocess = transforms.Compose([to_tensor, normalize])

    voc_data = ListDataset(
        root='/mnt/hgfs/D/download/PASCAL_VOC/voc_all_images',
        list_file='./data/voc12_train.txt',
        train=True,
        transform=preprocess,
        input_size=600,
    )
    loader = torch.utils.data.DataLoader(
        voc_data,
        batch_size=8,
        shuffle=False,
        num_workers=1,
        collate_fn=voc_data.collate_fn,
    )

    # Only the first batch is inspected; a private sentinel distinguishes
    # "loader is empty" from any value a batch could legitimately be.
    no_batch = object()
    batches = iter(loader)
    first = next(batches, no_batch)
    if first is no_batch:
        return

    images, loc_targets, cls_targets = first
    for tensor in (images, loc_targets, cls_targets):
        print(tensor.size())

    # nrow=1 puts a single image in each grid row.
    image_grid = torchvision.utils.make_grid(images, nrow=1)
    torchvision.utils.save_image(image_grid, 'a.jpg')
# test()
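# (Web page residue: "comment list" navigation heading; not part of the program.)
# (Web page residue: "article table of contents" navigation heading; not part of the program.)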