def get_mean_and_std(dataset, max_load=10000):
    '''Compute per-channel mean and std estimates for a 3-channel image dataset.

    Calls ``dataset.load(1)`` ``N = min(max_load, len(dataset))`` times and
    averages, over those calls, the per-channel mean and the per-channel
    std of each loaded image tensor.

    Note: the returned std is the *average of per-load standard deviations*,
    not the standard deviation pooled over every pixel of the dataset; the
    two generally differ.

    Args:
        dataset: object supporting ``len()`` and a ``load(k)`` method that
            returns a 3-tuple whose first element is a 4-D image tensor
            indexed as ``[batch, channel, height, width]``. Presumably
            ``load(1)`` draws one sample per call (possibly at random) --
            TODO confirm; the loop index is not passed to it.
        max_load: upper bound on the number of ``load`` calls.

    Returns:
        (mean, std): two float tensors of shape ``(3,)``.

    Raises:
        ValueError: if no samples would be loaded (empty dataset or
            ``max_load <= 0``), rather than silently returning NaNs.
    '''
    n_loads = min(max_load, len(dataset))
    # Without this guard, div_(0) yields NaN and a negative count yields
    # sign-flipped garbage -- both silently poison downstream normalization.
    if n_loads <= 0:
        raise ValueError(
            'get_mean_and_std needs at least one sample; got '
            'len(dataset)={}, max_load={}'.format(len(dataset), max_load))

    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    for _ in range(n_loads):
        im, _, _ = dataset.load(1)
        for c in range(3):
            channel = im[:, c, :, :]
            mean[c] += channel.mean()
            std[c] += channel.std()
    mean.div_(n_loads)
    std.div_(n_loads)
    return mean, std
评论列表
文章目录