def check_aug(nfold=0, sample_index=3, max_views=13):
    """Visually inspect the augmented images produced for one training sample.

    Builds ``CSVDataset_tst`` from ``../../_data/fold{nfold}/train.csv``, wraps
    it in an unshuffled DataLoader with ``batch_size=1``, and takes the batch at
    position ``sample_index``. It prints the shapes of the inputs and labels
    plus the max/min/mean of the inputs. It then plots up to ``max_views``
    images from the inputs on a 3x5 grid and shows the figure.

    The defaults reproduce the original hard-coded behavior (fold 0, fourth
    batch, 13 images).

    Args:
        nfold: Fold number used to build the CSV path.
        sample_index: Zero-based position of the batch to inspect.
        max_views: Maximum number of images to plot. It is further capped by
            the number of images available and by the 15 cells of the grid.

    Raises:
        IndexError: If the loader yields no batch at ``sample_index``.
    """
    import itertools

    grid_rows, grid_cols = 3, 5
    tst_dataset = CSVDataset_tst(f'../../_data/fold{nfold}/train.csv')
    tst = data.DataLoader(tst_dataset, batch_size=1, shuffle=False, num_workers=8)

    # Jump straight to the requested batch and stop iterating once it is
    # found, instead of enumerating and testing every index.
    val_data = next(itertools.islice(tst, sample_index, None), None)
    if val_data is None:
        # Previously this case silently showed an empty figure.
        raise IndexError(
            f'sample_index {sample_index} is out of range for fold {nfold}')

    inputs, labels = val_data
    # Convert to numpy and drop the leading batch dimension (batch_size=1).
    inputs, labels = inputs.numpy()[0], labels.numpy()[0]
    print(inputs.shape, labels.shape, np.amax(inputs), np.amin(inputs), np.mean(inputs))

    # Never index past the available images or past the last grid cell.
    n_views = min(max_views, len(inputs), grid_rows * grid_cols)
    for i in range(n_views):
        plt.subplot(grid_rows, grid_cols, 1 + i)
        # Each inputs[i] must be 3-D for this transpose. Presumably it is
        # channel-first (C, H, W), reordered to (H, W, C) for imshow
        # — TODO confirm against CSVDataset_tst.
        plt.imshow(np.transpose(inputs[i], (1, 2, 0)))
    plt.show()
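# Web-page residue scraped along with this code (blog UI labels, not code):
# "评论列表" = "Comment list", "文章目录" = "Table of contents".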