data_loader.py source code

python

Project: deep-spike    Author: electronicvisions
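import numpy as np
from sklearn.datasets import load_digits


def load_small_digits(train_prop, n_class):
    '''
    Load the 8x8 digits dataset from scikit-learn and split it into
    training and testing sets with one-hot encoded targets.

    :param train_prop: proportion of samples in the training set
    :param n_class: number of different digits (classes 0 .. n_class-1)
    :return: x_train, x_test, z_train, z_test, where x_* hold the flattened
        64-pixel images and z_* the one-hot class vectors
    '''

    # Load the 8 by 8 digit dataset (only digits 0 .. n_class-1);
    # n_class is keyword-only in current scikit-learn releases
    data = load_digits(n_class=n_class)
    N_images = data.target.size
    N_train = int(N_images * train_prop)
    N_test = N_images - N_train

    # The first N_train images are used for training and the rest for
    # testing; no shuffling is applied
    x_train = data.data[:N_train, :]
    x_test = data.data[N_train:, :]

    class_train = data.target[:N_train]
    class_test = data.target[N_train:]

    # One-hot encode: row i gets a 1 in column class_*[i]
    z_train = np.zeros((N_train, n_class))
    z_train[np.arange(N_train), class_train] = 1
    z_test = np.zeros((N_test, n_class))
    z_test[np.arange(N_test), class_test] = 1

    return x_train, x_test, z_train, z_test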
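The split and the one-hot trick (`z[np.arange(N), labels] = 1`) do not depend on scikit-learn, so they can be checked in isolation. The sketch below is a minimal illustration under that assumption: `one_hot` and `split_and_encode` are hypothetical helpers, not part of deep-spike, and the toy arrays stand in for `data.data` and `data.target`. It mirrors the same "first N_train rows for training, the rest for testing" rule and shows that `argmax` along each row recovers the original labels.

```python
import numpy as np


def one_hot(labels, n_class):
    """Hypothetical helper: (N,) integer labels -> (N, n_class) one-hot array."""
    labels = np.asarray(labels)
    z = np.zeros((labels.size, n_class))
    # Paired integer indexing: row i gets a 1 in column labels[i]
    z[np.arange(labels.size), labels] = 1
    return z


def split_and_encode(x, y, train_prop, n_class):
    """Hypothetical helper: same unshuffled split as load_small_digits."""
    n_train = int(len(y) * train_prop)
    x_train, x_test = x[:n_train], x[n_train:]
    z_train = one_hot(y[:n_train], n_class)
    z_test = one_hot(y[n_train:], n_class)
    return x_train, x_test, z_train, z_test


# Toy stand-in data: 10 samples with 4 features, labels in 0..2
x = np.arange(40, dtype=float).reshape(10, 4)
y = np.array([0, 1, 2, 0, 1, 2, 0, 1, 2, 0])

x_train, x_test, z_train, z_test = split_and_encode(x, y, 0.8, 3)
print(x_train.shape, z_train.shape, x_test.shape, z_test.shape)
# (8, 4) (8, 3) (2, 4) (2, 3)
print(z_test.argmax(axis=1))  # recovers y[8:]
# [2 0]
```

With the full 10-class digits set (1,797 images of 64 pixels each), `load_small_digits(0.8, 10)` would yield 1,437 training rows and 360 test rows, since `int(1797 * 0.8)` truncates to 1437. Because the split is not shuffled, any ordering in the underlying dataset carries over into the train/test partition.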