import pickle

import numpy as np

# CIFAR10_DIR, Dataset, Datasets, to_one_hot_enc and redivide_data are
# expected to be defined elsewhere in the module.


def load_cifar10(folder=CIFAR10_DIR, one_hot=True, partitions=None, filters=None, maps=None, balance_classes=False):
    path = folder + "/cifar-10.pickle"
    with open(path, "rb") as input_file:
        X, target_name, files = pickle.load(input_file)
    X = np.array(X)

    # Map each class name to an integer ID (sorted for a deterministic ordering)
    # and keep the inverse mapping as well.
    list_of_targets = sorted(set(target_name))
    dict_name_ID = {name: i for i, name in enumerate(list_of_targets)}
    dict_ID_name = {v: k for k, v in dict_name_ID.items()}

    # Convert the string labels to integer IDs (optionally one-hot encoded).
    Y = [dict_name_ID[name_y] for name_y in target_name]
    if one_hot:
        Y = to_one_hot_enc(Y)

    dataset = Dataset(data=X, target=Y,
                      info={'dict_name_ID': dict_name_ID, 'dict_ID_name': dict_ID_name},
                      sample_info=[{'target_name': t, 'files': f} for t, f in zip(target_name, files)])

    if partitions:
        res = redivide_data([dataset], partitions, filters=filters, maps=maps,
                            shuffle=True, balance_classes=balance_classes)
        # Pad with None so train/validation/test always unpack below.
        res += [None] * (3 - len(res))
        return Datasets(train=res[0], validation=res[1], test=res[2])
    return dataset
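
A minimal usage sketch, assuming CIFAR10_DIR points at a folder containing cifar-10.pickle and that the Dataset/Datasets helpers used above are available in the same module; the (0.8, 0.2) split and the .data/.target attribute names are illustrative assumptions, not confirmed by the code above.

# Hypothetical usage of load_cifar10 (split ratio and attribute names assumed).
datasets = load_cifar10(partitions=(0.8, 0.2), one_hot=True)
train, validation = datasets.train, datasets.validation  # test will be None here
print(train.data.shape, train.target.shape)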