datasets.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:RFHO 作者: lucfra 项目源码 文件源码
def load_caltech101(folder=CALTECH101_DIR, one_hot=True, partitions=None, filters=None, maps=None):
    """Load the Caltech-101 dataset from a pickled file.

    The file ``<folder>/caltech101.pickle`` must unpickle to a 3-tuple
    ``(X, target_name, files)``. Distinct class names are mapped to integer
    IDs following their sorted order, so the mapping is deterministic.

    :param folder: directory containing ``caltech101.pickle``; a ``str`` or
        any path-like object.
    :param one_hot: if True, the integer targets are passed through
        ``to_one_hot_enc``; otherwise they stay a list of integer class IDs.
    :param partitions: if truthy, forwarded to ``redivide_data`` (called with
        ``shuffle=True``) to split the data; presumably split proportions or
        sizes -- confirm against ``redivide_data``.
    :param filters: forwarded unchanged to ``redivide_data``.
    :param maps: forwarded unchanged to ``redivide_data``.
    :return: a single ``Dataset`` when ``partitions`` is falsy; otherwise a
        ``Datasets(train=..., validation=..., test=...)`` where any split not
        produced by ``redivide_data`` is ``None``.
    """
    import os

    # os.path.join accepts str and path-like folders (plain "+" fails on pathlib.Path).
    path = os.path.join(folder, "caltech101.pickle")
    # NOTE(review): unpickling can execute arbitrary code; only load this file from a trusted source.
    with open(path, "rb") as input_file:
        X, target_name, files = cpickle.load(input_file)

    # IDs follow the sorted order of the distinct class names.
    dict_name_ID = {name: idx for idx, name in enumerate(sorted(set(target_name)))}
    dict_ID_name = {idx: name for name, idx in dict_name_ID.items()}

    Y = [dict_name_ID[name_y] for name_y in target_name]
    if one_hot:
        Y = to_one_hot_enc(Y)

    dataset = Dataset(data=X, target=Y,
                      info={'dict_name_ID': dict_name_ID, 'dict_ID_name': dict_ID_name},
                      sample_info=[{'target_name': t, 'files': f} for t, f in zip(target_name, files)])
    if partitions:
        # Normalize to a list so padding works whatever sequence type is returned.
        res = list(redivide_data([dataset], partitions, filters=filters, maps=maps, shuffle=True))
        # Pad to at least three entries so missing splits become None.
        res += [None] * (3 - len(res))
        return Datasets(train=res[0], validation=res[1], test=res[2])
    return dataset
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号