def load_cifar100(folder=CIFAR100_DIR, one_hot=True, partitions=None, filters=None, maps=None):
    """Load CIFAR-100 from ``<folder>/cifar-100.pickle`` and wrap it in a ``Dataset``.

    The pickle must unpack into exactly six items, in this order: the
    samples, the fine target ids, the coarse target ids, the fine label
    list, the coarse label list, and the per-sample ``files`` entries.
    Both target sequences are truncated to the number of samples.

    NOTE: unpickling can execute arbitrary code, so ``folder`` must only
    point at data from a trusted source.

    Args:
        folder: Directory that contains ``cifar-100.pickle``.
        one_hot: When true, the fine and coarse targets are both passed
            through ``to_one_hot_enc`` before being stored.
        partitions: When truthy, forwarded to ``redivide_data`` to split the
            loaded dataset; when falsy, no split is performed.
        filters: Forwarded unchanged to ``redivide_data``.
        maps: Forwarded unchanged to ``redivide_data``.

    Returns:
        The single ``Dataset`` when ``partitions`` is falsy. Otherwise a
        ``Datasets(train, validation, test)`` built from the first three
        results of ``redivide_data``; slots for which no partition was
        produced are ``None``.
    """
    path = folder + "/cifar-100.pickle"
    # Keep the file open only for as long as the load itself needs it.
    with open(path, "rb") as input_file:
        (X, target_ID_fine, target_ID_coarse,
         fine_ID_corr, coarse_ID_corr, files) = cpickle.load(input_file)

    X = np.array(X)
    n_samples = len(X)

    # Targets may be longer than the sample array; keep only matching entries.
    Y = np.array(target_ID_fine[:n_samples])
    superY = np.array(target_ID_coarse[:n_samples])
    if one_hot:
        Y = to_one_hot_enc(Y)
        superY = to_one_hot_enc(superY)

    # Position in the stored label list is the id: build id -> label maps,
    # then invert them to obtain label -> id maps.
    fine_ID_corr = dict(enumerate(fine_ID_corr))
    coarse_ID_corr = dict(enumerate(coarse_ID_corr))
    fine_label_corr = {label: idx for idx, label in fine_ID_corr.items()}
    coarse_label_corr = {label: idx for idx, label in coarse_ID_corr.items()}

    # Size report retained so the loader's console output is unchanged.
    print(len(X))
    print(len(Y))

    dataset = Dataset(
        data=X,
        target=Y,
        info={
            'dict_name_ID_fine': fine_label_corr,
            'dict_name_ID_coarse': coarse_label_corr,
            'dict_ID_name_fine': fine_ID_corr,
            'dict_ID_name_coarse': coarse_ID_corr,
        },
        # zip stops at the shorter input, so extra `files` entries are ignored.
        sample_info=[{'Y_coarse': yc, 'files': f} for yc, f in zip(superY, files)],
    )

    if partitions:
        res = redivide_data([dataset], partitions, filters=filters, maps=maps, shuffle=True)
        # Pad to three entries so train/validation/test can always be indexed.
        res += [None] * (3 - len(res))
        return Datasets(train=res[0], validation=res[1], test=res[2])
    return dataset
评论列表
文章目录