data_process.py source code

Project: pytorch_crowd_count, author: BingzheWu
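```python
from sklearn.model_selection import train_test_split


def gen_train_data(dataset_paths):
    # Uses helpers and globals defined elsewhere in data_process.py:
    # load_images_and_gts, multiscale_pyramidal, generate_slices,
    # flip_slices, samples_distribution, shuffle_slices, patch_w, patch_h.
    X_fs = []
    Y_fs = []

    # Gather every image and its density map from all dataset paths;
    # the raw ground-truth annotations (gts) are not used here.
    for path in dataset_paths:
        images, _gts, densities = load_images_and_gts(path)
        X_fs.extend(images)
        Y_fs.extend(densities)

    # Hold out 20% of the images. Only the training split is augmented and
    # returned; X_test / Y_test are discarded. No random_state is set, so
    # the split differs between runs.
    X_train, X_test, Y_train, Y_test = train_test_split(X_fs, Y_fs, test_size=0.2)
    print("after split:", len(X_train))

    X_train, Y_train = multiscale_pyramidal(X_train, Y_train)
    # X_train, Y_train = adapt_images_and_densities(X_train, Y_train, slice_w, slice_h)
    print("after multiscale pyramid:", len(X_train))

    # Cut patch_w x patch_h slices from each image/density pair.
    X_train, Y_train = generate_slices(X_train, Y_train, slice_w=patch_w, slice_h=patch_h, offset=8)
    print("after slicing:", len(X_train))

    # X_train, Y_train = crop_slices(X_train, Y_train)
    X_train, Y_train = flip_slices(X_train, Y_train)
    print("after flipping:", len(X_train))

    X_train, Y_train = samples_distribution(X_train, Y_train)
    print("after samples_distribution:", len(X_train))

    X_train, Y_train = shuffle_slices(X_train, Y_train)
    return X_train, Y_train
```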
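The helpers called by gen_train_data are defined elsewhere in data_process.py and are not shown on this page. As a rough illustration of what a multi-scale pyramid step must preserve, here is a hypothetical sketch (downscale_pair and pyramid are names invented for this example, not the project's multiscale_pyramidal). It assumes NumPy arrays with density maps shaped (H, W): image pixels are averaged over factor x factor blocks while density values are summed, so every downscaled density map keeps the same total count whenever H and W divide evenly by the factor.

```python
import numpy as np


def downscale_pair(image, density, factor=2):
    """Hypothetical helper: downscale an image and its density map together.

    Image pixels are averaged over factor x factor blocks; density values are
    summed, so density.sum() is unchanged when H and W divide by factor.
    Trailing rows/columns that do not fill a whole block are cropped.
    """
    h = (density.shape[0] // factor) * factor
    w = (density.shape[1] // factor) * factor
    image = image[:h, :w]
    density = density[:h, :w]

    # Group pixels into (H/f, f, W/f, f[, C]) blocks, then reduce the block axes.
    img_blocks = image.reshape(h // factor, factor, w // factor, factor, *image.shape[2:])
    den_blocks = density.reshape(h // factor, factor, w // factor, factor)
    return img_blocks.mean(axis=(1, 3)), den_blocks.sum(axis=(1, 3))


def pyramid(images, densities, levels=2, factor=2):
    """Hypothetical helper: keep each original plus `levels` downscaled copies."""
    out_x, out_y = [], []
    for img, den in zip(images, densities):
        for _ in range(levels + 1):
            out_x.append(img)
            out_y.append(den)
            img, den = downscale_pair(img, den, factor)
    return out_x, out_y
```

Summing rather than averaging the density map is the key choice: averaging would shrink the head count by factor² at every level.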
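The call generate_slices(..., slice_w=patch_w, slice_h=patch_h, offset=8) cuts fixed-size patches from each image. Its body is not shown here, so the following is only a hypothetical sliding-window sketch (sliding_slices is an invented name), and treating a step like offset as the stride between neighbouring windows is an assumption. Image and density map are cropped with the same window so every patch stays aligned with its labels.

```python
import numpy as np


def sliding_slices(images, densities, slice_w, slice_h, stride):
    """Hypothetical helper: cut aligned slice_h x slice_w patches.

    Windows start every `stride` pixels in both directions; windows that
    would run past the image border are skipped.
    """
    out_x, out_y = [], []
    for img, den in zip(images, densities):
        h, w = den.shape[:2]
        for top in range(0, h - slice_h + 1, stride):
            for left in range(0, w - slice_w + 1, stride):
                # Copy so later in-place augmentation cannot alias the source.
                out_x.append(img[top:top + slice_h, left:left + slice_w].copy())
                out_y.append(den[top:top + slice_h, left:left + slice_w].copy())
    return out_x, out_y
```

A stride smaller than the patch size produces overlapping patches, which multiplies the number of training samples at the cost of strong correlation between neighbours.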
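For a flip augmentation such as the flip_slices step, the density map has to be mirrored exactly like its image, otherwise the labels no longer line up with the people in the picture. A minimal sketch, assuming NumPy arrays laid out as (H, W) or (H, W, C); flip_pairs is a hypothetical name, not the project's function, and horizontal mirroring is an assumed choice.

```python
import numpy as np


def flip_pairs(images, densities):
    """Hypothetical helper: append a horizontally mirrored copy of each pair.

    The same np.flip(..., axis=1) is applied to the image and its density map,
    so the flipped density stays aligned and keeps the same total count.
    """
    out_x, out_y = list(images), list(densities)
    for img, den in zip(images, densities):
        out_x.append(np.flip(img, axis=1))
        out_y.append(np.flip(den, axis=1))
    return out_x, out_y
```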
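The final shuffle_slices step must reorder images and density maps with one shared permutation so that pairs are not broken apart. A minimal sketch, assuming both inputs are Python lists of equal length; shuffle_pairs is a hypothetical name, and the optional seed is added here for reproducibility.

```python
import numpy as np


def shuffle_pairs(xs, ys, seed=None):
    """Hypothetical helper: shuffle two parallel lists with one permutation."""
    if len(xs) != len(ys):
        raise ValueError("xs and ys must have the same length")
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(xs))
    return [xs[i] for i in order], [ys[i] for i in order]
```

Passing a fixed seed makes the sample order repeatable across runs, which helps when comparing training runs.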