datasets.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:RFHO 作者: lucfra 项目源码 文件源码
def load_cifar100(folder=CIFAR100_DIR, one_hot=True, partitions=None, filters=None, maps=None):
    """Load the pickled CIFAR-100 dataset stored in ``folder``.

    Reads ``<folder>/cifar-100.pickle``, which must unpickle to the 6-tuple
    ``(X, target_ID_fine, target_ID_coarse, fine_ID_corr, coarse_ID_corr, files)``.

    Args:
        folder: directory containing ``cifar-100.pickle``.
        one_hot: if True, both fine and coarse labels are passed through
            ``to_one_hot_enc``; otherwise the raw label IDs are kept.
        partitions: if truthy, forwarded to ``redivide_data`` (with
            ``shuffle=True``) to split the data into up to three parts.
        filters: forwarded unchanged to ``redivide_data``.
        maps: forwarded unchanged to ``redivide_data``.

    Returns:
        A single ``Dataset`` when ``partitions`` is falsy.
        Otherwise, a ``Datasets(train=..., validation=..., test=...)`` in which
        splits not produced by ``redivide_data`` are ``None``.

    The returned ``Dataset`` uses the fine labels as ``target``. Its ``info``
    holds name->ID and ID->name dicts for both the fine and coarse label sets.
    Each ``sample_info`` entry carries that sample's coarse label
    (``'Y_coarse'``, one-hot if ``one_hot``) and its ``'files'`` entry.
    """
    import os

    path = os.path.join(folder, "cifar-100.pickle")
    # NOTE(review): unpickling can execute arbitrary code -- only load trusted files.
    with open(path, "rb") as input_file:
        X, target_ID_fine, target_ID_coarse, fine_ID_corr, coarse_ID_corr, files = cpickle.load(input_file)
    X = np.array(X)

    # The label lists may be longer than the sample array; keep one label per sample.
    n_samples = len(X)
    Y = np.array(target_ID_fine[:n_samples])
    superY = np.array(target_ID_coarse[:n_samples])
    if one_hot:
        Y = to_one_hot_enc(Y)
        superY = to_one_hot_enc(superY)

    # The position of each entry in the pickled sequences is used as its class ID.
    # Presumably the entries are class names -- confirm against the pickle producer.
    fine_ID_corr = dict(enumerate(fine_ID_corr))
    coarse_ID_corr = dict(enumerate(coarse_ID_corr))
    fine_label_corr = {name: class_id for class_id, name in fine_ID_corr.items()}
    coarse_label_corr = {name: class_id for class_id, name in coarse_ID_corr.items()}

    dataset = Dataset(data=X, target=Y,
                      info={'dict_name_ID_fine': fine_label_corr, 'dict_name_ID_coarse': coarse_label_corr,
                            'dict_ID_name_fine': fine_ID_corr, 'dict_ID_name_coarse': coarse_ID_corr},
                      # zip stops at the shorter input, so extra `files` entries are ignored.
                      sample_info=[{'Y_coarse': yc, 'files': f} for yc, f in zip(superY, files)])
    if not partitions:
        return dataset

    # Normalize to a list so padding works even if a tuple is returned.
    res = list(redivide_data([dataset], partitions, filters=filters, maps=maps, shuffle=True))
    # Pad to exactly train/validation/test; missing splits become None.
    res += [None] * (3 - len(res))
    return Datasets(train=res[0], validation=res[1], test=res[2])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号